From 6b17d3fb9460c39c9054070404a63ef004c9909c Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Fri, 8 Sep 2023 08:15:25 +0200 Subject: [PATCH 1/5] Simplify boilerplate for P2P shuffles (#8174) --- distributed/shuffle/_core.py | 22 ++++++++++++++++++++++ distributed/shuffle/_rechunk.py | 13 ++++--------- distributed/shuffle/_shuffle.py | 18 ++++-------------- distributed/shuffle/tests/test_rechunk.py | 2 +- distributed/shuffle/tests/test_shuffle.py | 6 +++--- 5 files changed, 34 insertions(+), 27 deletions(-) diff --git a/distributed/shuffle/_core.py b/distributed/shuffle/_core.py index af7c83cafa..9f9763ab09 100644 --- a/distributed/shuffle/_core.py +++ b/distributed/shuffle/_core.py @@ -334,3 +334,25 @@ def __str__(self) -> str: def __hash__(self) -> int: return hash(self.run_id) + + +@contextlib.contextmanager +def handle_transfer_errors(id: ShuffleId) -> Iterator[None]: + try: + yield + except ShuffleClosedError: + raise Reschedule() + except Exception as e: + raise RuntimeError(f"P2P shuffling {id} failed during transfer phase") from e + + +@contextlib.contextmanager +def handle_unpack_errors(id: ShuffleId) -> Iterator[None]: + try: + yield + except Reschedule as e: + raise e + except ShuffleClosedError: + raise Reschedule() + except Exception as e: + raise RuntimeError(f"P2P shuffling {id} failed during unpack phase") from e diff --git a/distributed/shuffle/_rechunk.py b/distributed/shuffle/_rechunk.py index 45ae05013e..dc8bd0371a 100644 --- a/distributed/shuffle/_rechunk.py +++ b/distributed/shuffle/_rechunk.py @@ -111,13 +111,14 @@ from dask.highlevelgraph import HighLevelGraph, MaterializedLayer from distributed.core import PooledRPCCall -from distributed.exceptions import Reschedule from distributed.shuffle._core import ( NDIndex, ShuffleId, ShuffleRun, ShuffleSpec, get_worker_plugin, + handle_transfer_errors, + handle_unpack_errors, ) from distributed.shuffle._limiter import ResourceLimiter from distributed.shuffle._scheduler_plugin import ShuffleSchedulerPlugin @@ -143,27 +144,21 @@ def rechunk_transfer( new: ChunkedAxes, old: ChunkedAxes, ) -> int: - try: + with handle_transfer_errors(id): return get_worker_plugin().add_partition( input, partition_id=input_chunk, spec=ArrayRechunkSpec(id=id, new=new, old=old), ) - except Exception as e: - raise RuntimeError(f"rechunk_transfer failed during shuffle {id}") from e def rechunk_unpack( id: ShuffleId, output_chunk: NDIndex, barrier_run_id: int ) -> np.ndarray: - try: + with handle_unpack_errors(id): return get_worker_plugin().get_output_partition( id, barrier_run_id, output_chunk ) - except Reschedule as e: - raise e - except Exception as e: - raise RuntimeError(f"rechunk_unpack failed during shuffle {id}") from e def rechunk_p2p(x: da.Array, chunks: ChunkedAxes) -> da.Array: diff --git a/distributed/shuffle/_shuffle.py b/distributed/shuffle/_shuffle.py index a5aef7cd21..7d0b0a6a8f 100644 --- a/distributed/shuffle/_shuffle.py +++ b/distributed/shuffle/_shuffle.py @@ -17,7 +17,6 @@ from dask.typing import Key from distributed.core import PooledRPCCall -from distributed.exceptions import Reschedule from distributed.shuffle._arrow import ( check_dtype_support, check_minimal_arrow_version, @@ -32,8 +31,9 @@ ShuffleSpec, barrier_key, get_worker_plugin, + handle_transfer_errors, + handle_unpack_errors, ) -from distributed.shuffle._exceptions import ShuffleClosedError from distributed.shuffle._limiter import ResourceLimiter from distributed.shuffle._scheduler_plugin import ShuffleSchedulerPlugin from distributed.shuffle._worker_plugin 
import ShuffleWorkerPlugin @@ -59,7 +59,7 @@ def shuffle_transfer( meta: pd.DataFrame, parts_out: set[int], ) -> int: - try: + with handle_transfer_errors(id): return get_worker_plugin().add_partition( input, input_partition, @@ -71,25 +71,15 @@ def shuffle_transfer( parts_out=parts_out, ), ) - except ShuffleClosedError: - raise Reschedule() - except Exception as e: - raise RuntimeError(f"shuffle_transfer failed during shuffle {id}") from e def shuffle_unpack( id: ShuffleId, output_partition: int, barrier_run_id: int ) -> pd.DataFrame: - try: + with handle_unpack_errors(id): return get_worker_plugin().get_output_partition( id, barrier_run_id, output_partition ) - except Reschedule as e: - raise e - except ShuffleClosedError: - raise Reschedule() - except Exception as e: - raise RuntimeError(f"shuffle_unpack failed during shuffle {id}") from e def shuffle_barrier(id: ShuffleId, run_ids: list[int]) -> int: diff --git a/distributed/shuffle/tests/test_rechunk.py b/distributed/shuffle/tests/test_rechunk.py index fcdee28f85..509025e767 100644 --- a/distributed/shuffle/tests/test_rechunk.py +++ b/distributed/shuffle/tests/test_rechunk.py @@ -236,7 +236,7 @@ async def test_rechunk_with_single_output_chunk_raises(c, s, *ws): assert x2.chunks == new # FIXME: distributed#7816 with raises_with_cause( - RuntimeError, "rechunk_transfer failed", RuntimeError, "Barrier task" + RuntimeError, "failed during transfer", RuntimeError, "Barrier task" ): await c.compute(x2) diff --git a/distributed/shuffle/tests/test_shuffle.py b/distributed/shuffle/tests/test_shuffle.py index d211f55f92..453fc279f4 100644 --- a/distributed/shuffle/tests/test_shuffle.py +++ b/distributed/shuffle/tests/test_shuffle.py @@ -222,7 +222,7 @@ async def test_shuffle_with_array_conversion(c, s, a, b, lose_annotations, npart if npartitions == 1: # FIXME: distributed#7816 with raises_with_cause( - RuntimeError, "shuffle_transfer failed", RuntimeError, "Barrier task" + RuntimeError, "failed during transfer", RuntimeError, "Barrier task" ): await c.compute(out) else: @@ -285,7 +285,7 @@ async def test_bad_disk(c, s, a, b): while not b.plugins["shuffle"].shuffles: await asyncio.sleep(0.01) shutil.rmtree(b.local_directory) - with pytest.raises(RuntimeError, match=f"shuffle_transfer failed .* {shuffle_id}"): + with pytest.raises(RuntimeError, match=f"{shuffle_id} failed during transfer"): out = await c.compute(out) await c.close() @@ -2143,7 +2143,7 @@ def make_partition(i): ddf = dd.from_map(make_partition, range(50)) out = ddf.shuffle(on="a", shuffle="p2p", ignore_index=True) with raises_with_cause( - RuntimeError, "shuffle_transfer", ValueError, "could not convert" + RuntimeError, "failed during transfer", ValueError, "could not convert" ): await c.compute(out) From 4e90a0aa468a978045256c969a1db89ab60762c0 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Fri, 8 Sep 2023 13:17:27 +0200 Subject: [PATCH 2/5] Introduce `Client.register_plugin()` (#8169) Co-authored-by: Florian Jetter --- distributed/client.py | 236 ++++++++++++++---- distributed/diagnostics/plugin.py | 33 +-- .../tests/test_cluster_dump_plugin.py | 2 +- .../diagnostics/tests/test_nanny_plugin.py | 82 +++++- .../tests/test_scheduler_plugin.py | 68 ++++- .../diagnostics/tests/test_worker_plugin.py | 106 ++++++-- distributed/nanny.py | 5 +- distributed/scheduler.py | 34 ++- distributed/shuffle/_scheduler_plugin.py | 2 +- distributed/shuffle/tests/test_shuffle.py | 14 +- distributed/shuffle/tests/utils.py | 2 +- distributed/tests/test_chaos.py | 2 +- 
distributed/tests/test_client.py | 24 +- distributed/tests/test_nanny.py | 10 +- distributed/tests/test_scheduler.py | 4 +- distributed/tests/test_stress.py | 2 +- distributed/tests/test_worker.py | 26 +- distributed/worker.py | 6 +- 18 files changed, 516 insertions(+), 142 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index e9561d1985..0be88cf33d 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -21,11 +21,11 @@ from concurrent.futures._base import DoneAndNotDoneFutures from contextlib import asynccontextmanager, contextmanager, suppress from contextvars import ContextVar -from functools import partial +from functools import partial, singledispatchmethod from importlib.metadata import PackageNotFoundError, version from numbers import Number from queue import Queue as pyQueue -from typing import Any, ClassVar, Literal, NamedTuple, TypedDict +from typing import Any, Callable, ClassVar, Literal, NamedTuple, TypedDict, cast from packaging.version import parse as parse_version from tlz import first, groupby, merge, partition_all, valmap @@ -46,7 +46,7 @@ ) from dask.widgets import get_template -from distributed.core import ErrorMessage +from distributed.core import ErrorMessage, OKMessage from distributed.protocol.serialize import _is_dumpable from distributed.utils import Deadline, wait_for @@ -75,6 +75,7 @@ from distributed.diagnostics.plugin import ( ForwardLoggingPlugin, NannyPlugin, + SchedulerPlugin, SchedulerUploadFile, UploadFile, WorkerPlugin, @@ -3781,10 +3782,11 @@ def upload_file(self, filename, load: bool = True): async def _(): results = await asyncio.gather( - self.register_scheduler_plugin( + self.register_plugin( SchedulerUploadFile(filename, load=load), name=name ), - self.register_worker_plugin(UploadFile(filename, load=load), name=name), + # FIXME: Make scheduler plugin responsible for (de)registering worker plugin + self.register_plugin(UploadFile(filename, load=load), name=name), ) return results[1] # Results from workers upload @@ -4814,15 +4816,90 @@ async def _get_task_stream( else: return msgs - async def _register_scheduler_plugin(self, plugin, name, idempotent=False): + def register_plugin( + self, + plugin: NannyPlugin | SchedulerPlugin | WorkerPlugin, + name: str | None = None, + idempotent: bool = False, + ): + """Register a plugin. + + See https://distributed.readthedocs.io/en/latest/plugins.html + + Parameters + ---------- + plugin : + A nanny, scheduler, or worker plugin to register. + name : + Name for the plugin; if None, a name is taken from the + plugin instance or automatically generated if not present. + idempotent : + Do not re-register if a plugin of the given name already exists. + """ + if name is None: + name = _get_plugin_name(plugin) + assert name + + return self._register_plugin(plugin, name, idempotent) + + @singledispatchmethod + def _register_plugin( + self, + plugin: NannyPlugin | SchedulerPlugin | WorkerPlugin, + name: str, + idempotent: bool, + ): + raise TypeError( + "Registering duck-typed plugins is not allowed. Please inherit from " + "NannyPlugin, WorkerPlugin, or SchedulerPlugin to create a plugin." 
+        )
+
+    @_register_plugin.register
+    def _(self, plugin: SchedulerPlugin, name: str, idempotent: bool):
+        return self.sync(
+            self._register_scheduler_plugin,
+            plugin=plugin,
+            name=name,
+            idempotent=idempotent,
+        )
+
+    @_register_plugin.register
+    def _(
+        self, plugin: NannyPlugin, name: str, idempotent: bool
+    ) -> dict[str, OKMessage]:
+        return self.sync(
+            self._register_nanny_plugin,
+            plugin=plugin,
+            name=name,
+            idempotent=idempotent,
+        )
+
+    @_register_plugin.register
+    def _(self, plugin: WorkerPlugin, name: str, idempotent: bool):
+        return self.sync(
+            self._register_worker_plugin,
+            plugin=plugin,
+            name=name,
+            idempotent=idempotent,
+        )
+
+    async def _register_scheduler_plugin(
+        self, plugin: SchedulerPlugin, name: str, idempotent: bool
+    ):
         return await self.scheduler.register_scheduler_plugin(
             plugin=dumps(plugin),
             name=name,
             idempotent=idempotent,
         )

-    def register_scheduler_plugin(self, plugin, name=None, idempotent=False):
-        """Register a scheduler plugin.
+    def register_scheduler_plugin(
+        self, plugin: SchedulerPlugin, name: str | None = None, idempotent: bool = False
+    ):
+        """
+        Register a scheduler plugin.
+
+        .. deprecated:: 2023.9.2
+            Use :meth:`Client.register_plugin` instead.

         See https://distributed.readthedocs.io/en/latest/plugins.html#scheduler-plugins

@@ -4836,15 +4913,13 @@ def register_scheduler_plugin(self, plugin, name=None, idempotent=False):
         idempotent : bool
             Do not re-register if a plugin of the given name already exists.
         """
-        if name is None:
-            name = _get_plugin_name(plugin)
-
-        return self.sync(
-            self._register_scheduler_plugin,
-            plugin=plugin,
-            name=name,
-            idempotent=idempotent,
+        warnings.warn(
+            "`Client.register_scheduler_plugin` has been deprecated; "
+            "please use `Client.register_plugin` instead",
+            DeprecationWarning,
+            stacklevel=2,
         )
+        return cast(OKMessage, self.register_plugin(plugin, name, idempotent))

     async def _unregister_scheduler_plugin(self, name):
         return await self.scheduler.unregister_scheduler_plugin(name=name)
@@ -4875,7 +4950,7 @@ def unregister_scheduler_plugin(self, name):
         ...     pass

         >>> plugin = MyPlugin(1, 2, 3)
-        >>> client.register_scheduler_plugin(plugin, name='foo')
+        >>> client.register_plugin(plugin, name='foo')

         >>> client.unregister_scheduler_plugin(name='foo')

         See Also
         --------
@@ -4902,41 +4977,50 @@ def register_worker_callbacks(self, setup=None):
         setup : callable(dask_worker: Worker) -> None
             Function to register and run on all workers
         """
-        return self.register_worker_plugin(_WorkerSetupPlugin(setup))
+        return self.register_plugin(_WorkerSetupPlugin(setup))

-    async def _register_worker_plugin(self, plugin=None, name=None, nanny=None):
-        if nanny or nanny is None and isinstance(plugin, NannyPlugin):
-            if not isinstance(plugin, NannyPlugin):
-                warnings.warn(
-                    "Registering duck-typed plugins has been deprecated. "
-                    "Please make sure your plugin subclasses `NannyPlugin`.",
-                    DeprecationWarning,
-                    stacklevel=2,
-                )
-            method = self.scheduler.register_nanny_plugin
-        else:
-            if not isinstance(plugin, WorkerPlugin):
-                warnings.warn(
-                    "Registering duck-typed plugins has been deprecated. 
" - "Please make sure your plugin subclasses `WorkerPlugin`.", - DeprecationWarning, - stacklevel=2, + async def _register_worker_plugin( + self, plugin: WorkerPlugin, name: str, idempotent: bool + ) -> dict[str, OKMessage]: + responses = await self.scheduler.register_worker_plugin( + plugin=dumps(plugin), name=name, idempotent=idempotent + ) + for response in responses.values(): + if response["status"] == "error": + _, exc, tb = clean_exception( + response["exception"], response["traceback"] ) - method = self.scheduler.register_worker_plugin + assert exc + raise exc.with_traceback(tb) + return cast(dict[str, OKMessage], responses) - responses = await method(plugin=dumps(plugin), name=name) + async def _register_nanny_plugin( + self, plugin: NannyPlugin, name: str, idempotent: bool + ) -> dict[str, OKMessage]: + responses = await self.scheduler.register_nanny_plugin( + plugin=dumps(plugin), name=name, idempotent=idempotent + ) for response in responses.values(): if response["status"] == "error": _, exc, tb = clean_exception( response["exception"], response["traceback"] ) + assert exc raise exc.with_traceback(tb) - return responses + return cast(dict[str, OKMessage], responses) - def register_worker_plugin(self, plugin=None, name=None, nanny=None): + def register_worker_plugin( + self, + plugin: NannyPlugin | WorkerPlugin, + name: str | None = None, + nanny: bool | None = None, + ): """ Registers a lifecycle worker plugin for all current and future workers. + .. deprecated:: 2023.9.2 + Use :meth:`Client.register_plugin` instead. + This registers a new object to handle setup, task state transitions and teardown for workers in this cluster. The plugin will instantiate itself on all currently connected workers. It will also be run on any @@ -4982,11 +5066,11 @@ def register_worker_plugin(self, plugin=None, name=None, nanny=None): ... pass >>> plugin = MyPlugin(1, 2, 3) - >>> client.register_worker_plugin(plugin) + >>> client.register_plugin(plugin) You can get access to the plugin with the ``get_worker`` function - >>> client.register_worker_plugin(other_plugin, name='my-plugin') + >>> client.register_plugin(other_plugin, name='my-plugin') >>> def f(): ... worker = get_worker() ... plugin = worker.plugins['my-plugin'] @@ -4999,14 +5083,70 @@ def register_worker_plugin(self, plugin=None, name=None, nanny=None): distributed.WorkerPlugin unregister_worker_plugin """ + warnings.warn( + "`Client.register_worker_plugin` has been deprecated; " + "please use `Client.register_plugin` instead", + DeprecationWarning, + stacklevel=2, + ) if name is None: name = _get_plugin_name(plugin) assert name - return self.sync( - self._register_worker_plugin, plugin=plugin, name=name, nanny=nanny - ) + method: Callable + if isinstance(plugin, WorkerPlugin): + method = self._register_worker_plugin + if nanny is True: + warnings.warn( + "Registering a `WorkerPlugin` as a nanny plugin is not " + "allowed, registering as a worker plugin instead. " + "To register as a nanny plugin, inherit from `NannyPlugin`.", + UserWarning, + stacklevel=2, + ) + elif isinstance(plugin, NannyPlugin): + method = self._register_nanny_plugin + if nanny is False: + warnings.warn( + "Registering a `NannyPlugin` as a worker plugin is not " + "allowed, registering as a nanny plugin instead. 
" + "To register as a worker plugin, inherit from `WorkerPlugin`.", + UserWarning, + stacklevel=2, + ) + elif isinstance(plugin, SchedulerPlugin): # type: ignore[unreachable] + if nanny: + warnings.warn( + "Registering a `SchedulerPlugin` as a nanny plugin is not " + "allowed, registering as a scheduler plugin instead. " + "To register as a nanny plugin, inherit from `NannyPlugin`.", + UserWarning, + stacklevel=2, + ) + else: + warnings.warn( + "Registering a `SchedulerPlugin` as a worker plugin is not " + "allowed, registering as a scheduler plugin instead. " + "To register as a worker plugin, inherit from `WorkerPlugin`.", + UserWarning, + stacklevel=2, + ) + method = self._register_scheduler_plugin + else: + warnings.warn( + "Registering duck-typed plugins has been deprecated. " + "Please make sure your plugin inherits from `NannyPlugin` " + "or `WorkerPlugin`.", + DeprecationWarning, + stacklevel=2, + ) + if nanny is True: + method = self._register_nanny_plugin + else: + method = self._register_worker_plugin + + return self.sync(method, plugin=plugin, name=name, idempotent=False) async def _unregister_worker_plugin(self, name, nanny=None): if nanny: @@ -5030,7 +5170,7 @@ def unregister_worker_plugin(self, name, nanny=None): Parameters ---------- name : str - Name of the plugin to unregister. See the :meth:`Client.register_worker_plugin` + Name of the plugin to unregister. See the :meth:`Client.register_plugin` docstring for more information. Examples @@ -5048,12 +5188,12 @@ def unregister_worker_plugin(self, name, nanny=None): ... pass >>> plugin = MyPlugin(1, 2, 3) - >>> client.register_worker_plugin(plugin, name='foo') + >>> client.register_plugin(plugin, name='foo') >>> client.unregister_worker_plugin(name='foo') See Also -------- - register_worker_plugin + register_plugin """ return self.sync(self._unregister_worker_plugin, name=name, nanny=nanny) @@ -5190,7 +5330,7 @@ class : logging.StreamHandler # removed and torn down (see distributed.worker.Worker.plugin_add()), so # this is effectively idempotent, i.e., forwarding the same logger twice # won't cause every LogRecord to be forwarded twice - return self.register_worker_plugin( + return self.register_plugin( ForwardLoggingPlugin(logger_name, level, topic), plugin_name ) diff --git a/distributed/diagnostics/plugin.py b/distributed/diagnostics/plugin.py index 02ae91a2ae..2e36857cf4 100644 --- a/distributed/diagnostics/plugin.py +++ b/distributed/diagnostics/plugin.py @@ -34,9 +34,9 @@ class SchedulerPlugin: To implement a plugin: - 1. subclass this class + 1. inherit from this class 2. override some of its methods - 3. add the plugin to the scheduler with ``Scheduler.add_plugin(myplugin)``. + 3. register the plugin using :meth:`Client.register_plugin`. Examples -------- @@ -203,9 +203,11 @@ class WorkerPlugin: an event happens, the corresponding method on this class will be called. Note that the user code always runs within the Worker's main thread. - To implement a plugin implement some of the methods of this class and register - the plugin to your client in order to have it attached to every existing and - future workers with ``Client.register_worker_plugin``. + To implement a plugin: + + 1. inherit from this class + 2. override some of its methods + 3. register the plugin using :meth:`Client.register_plugin`. 
Examples -------- @@ -227,7 +229,7 @@ class WorkerPlugin: >>> import logging >>> plugin = ErrorLogger(logging) - >>> client.register_worker_plugin(plugin) # doctest: +SKIP + >>> client.register_plugin(plugin) # doctest: +SKIP """ def setup(self, worker): @@ -275,10 +277,11 @@ class NannyPlugin: to run code before the worker is started, or to restart the worker if necessary. - To implement a plugin implement some of the methods of this class and register - the plugin to your client in order to have it attached to every existing and - future nanny by passing ``nanny=True`` to - :meth:`Client.register_worker_plugin`. + To implement a plugin: + + 1. inherit from this class + 2. override some of its methods + 3. register the plugin using :meth:`Client.register_plugin`. The ``restart`` attribute is used to control whether or not a running ``Worker`` needs to be restarted when registering the plugin. @@ -474,7 +477,7 @@ class CondaInstall(PackageInstall): >>> from dask.distributed import CondaInstall >>> plugin = CondaInstall(packages=["scikit-learn"], conda_options=["--update-deps"]) - >>> client.register_worker_plugin(plugin) + >>> client.register_plugin(plugin) See Also -------- @@ -550,7 +553,7 @@ class PipInstall(PackageInstall): >>> from dask.distributed import PipInstall >>> plugin = PipInstall(packages=["scikit-learn"], pip_options=["--upgrade"]) - >>> client.register_worker_plugin(plugin) + >>> client.register_plugin(plugin) See Also -------- @@ -598,7 +601,7 @@ class UploadFile(WorkerPlugin): -------- >>> from distributed.diagnostics.plugin import UploadFile - >>> client.register_worker_plugin(UploadFile("/path/to/file.py")) # doctest: +SKIP + >>> client.register_plugin(UploadFile("/path/to/file.py")) # doctest: +SKIP """ name = "upload_file" @@ -722,7 +725,7 @@ class UploadDirectory(NannyPlugin): Examples -------- >>> from distributed.diagnostics.plugin import UploadDirectory - >>> client.register_worker_plugin(UploadDirectory("/path/to/directory"), nanny=True) # doctest: +SKIP + >>> client.register_plugin(UploadDirectory("/path/to/directory"), nanny=True) # doctest: +SKIP """ def __init__( @@ -859,7 +862,7 @@ class ForwardOutput(WorkerPlugin): >>> from dask.distributed import ForwardOutput >>> plugin = ForwardOutput() - >>> client.register_worker_plugin(plugin) + >>> client.register_plugin(plugin) """ def setup(self, worker): diff --git a/distributed/diagnostics/tests/test_cluster_dump_plugin.py b/distributed/diagnostics/tests/test_cluster_dump_plugin.py index 8c99a567ac..7ceb7b7181 100644 --- a/distributed/diagnostics/tests/test_cluster_dump_plugin.py +++ b/distributed/diagnostics/tests/test_cluster_dump_plugin.py @@ -8,7 +8,7 @@ @gen_cluster(client=True) async def test_cluster_dump_plugin(c, s, *workers, tmp_path): dump_file = tmp_path / "cluster_dump.msgpack.gz" - await c.register_scheduler_plugin(ClusterDump(str(dump_file)), name="cluster-dump") + await c.register_plugin(ClusterDump(str(dump_file)), name="cluster-dump") plugin = s.plugins["cluster-dump"] assert plugin.scheduler is s diff --git a/distributed/diagnostics/tests/test_nanny_plugin.py b/distributed/diagnostics/tests/test_nanny_plugin.py index eedf63c261..cded39d2cd 100644 --- a/distributed/diagnostics/tests/test_nanny_plugin.py +++ b/distributed/diagnostics/tests/test_nanny_plugin.py @@ -2,20 +2,96 @@ import pytest -from distributed import Nanny +from distributed import Nanny, NannyPlugin from distributed.utils_test import gen_cluster @gen_cluster(client=True, nthreads=[("", 1)], Worker=Nanny) -async def 
test_duck_typed_nanny_plugin_is_deprecated(c, s, a): - class DuckPlugin: +async def test_register_worker_plugin_is_deprecated(c, s, a): + class DuckPlugin(NannyPlugin): def setup(self, nanny): + nanny.foo = 123 + + def teardown(self, nanny): pass + n_existing_plugins = len(a.plugins) + assert not hasattr(a, "foo") + with pytest.warns(DeprecationWarning, match="register_worker_plugin.*deprecated"): + await c.register_worker_plugin(DuckPlugin()) + assert len(a.plugins) == n_existing_plugins + 1 + assert a.foo == 123 + + +@gen_cluster(client=True, nthreads=[("", 1)], Worker=Nanny) +async def test_register_worker_plugin_typing_over_nanny_keyword(c, s, a): + class DuckPlugin(NannyPlugin): + def setup(self, nanny): + nanny.foo = 123 + def teardown(self, nanny): pass n_existing_plugins = len(a.plugins) + assert not hasattr(a, "foo") + with pytest.warns(UserWarning, match="`NannyPlugin` as a worker plugin"): + await c.register_worker_plugin(DuckPlugin(), nanny=False) + assert len(a.plugins) == n_existing_plugins + 1 + assert a.foo == 123 + + +@gen_cluster(client=True, nthreads=[("", 1)], Worker=Nanny) +async def test_duck_typed_register_nanny_plugin_is_deprecated(c, s, a): + class DuckPlugin: + def setup(self, nanny): + nanny.foo = 123 + + def teardown(self, nanny): + pass + + n_existing_plugins = len(a.plugins) + assert not hasattr(a, "foo") with pytest.warns(DeprecationWarning, match="duck-typed.*NannyPlugin"): await c.register_worker_plugin(DuckPlugin(), nanny=True) assert len(a.plugins) == n_existing_plugins + 1 + assert a.foo == 123 + + +@gen_cluster(client=True, nthreads=[("", 1)], Worker=Nanny) +async def test_register_idempotent_plugins(c, s, a): + class IdempotentPlugin(NannyPlugin): + def __init__(self, instance=None): + self.name = "idempotentplugin" + self.instance = instance + + def setup(self, nanny): + if self.instance != "first": + raise RuntimeError( + "Only the first plugin should be started when idempotent is set" + ) + + first = IdempotentPlugin(instance="first") + await c.register_plugin(first, idempotent=True) + assert "idempotentplugin" in a.plugins + + second = IdempotentPlugin(instance="second") + await c.register_plugin(second, idempotent=True) + assert "idempotentplugin" in a.plugins + assert a.plugins["idempotentplugin"].instance == "first" + + +@gen_cluster(client=True, nthreads=[("", 1)], Worker=Nanny) +async def test_register_non_idempotent_plugins(c, s, a): + class NonIdempotentPlugin(NannyPlugin): + def __init__(self, instance=None): + self.name = "nonidempotentplugin" + self.instance = instance + + first = NonIdempotentPlugin(instance="first") + await c.register_plugin(first, idempotent=False) + assert "nonidempotentplugin" in a.plugins + + second = NonIdempotentPlugin(instance="second") + await c.register_plugin(second, idempotent=False) + assert "nonidempotentplugin" in a.plugins + assert a.plugins["nonidempotentplugin"].instance == "second" diff --git a/distributed/diagnostics/tests/test_scheduler_plugin.py b/distributed/diagnostics/tests/test_scheduler_plugin.py index 01e651bf8b..49c37149ca 100644 --- a/distributed/diagnostics/tests/test_scheduler_plugin.py +++ b/distributed/diagnostics/tests/test_scheduler_plugin.py @@ -4,7 +4,7 @@ import pytest -from distributed import Scheduler, SchedulerPlugin, Worker, get_worker +from distributed import Nanny, Scheduler, SchedulerPlugin, Worker, get_worker from distributed.protocol.pickle import dumps from distributed.utils_test import captured_logger, gen_cluster, gen_test, inc @@ -358,7 +358,7 @@ async def close(self): 
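The idempotent/non-idempotent semantics exercised by the test_nanny_plugin.py tests above reduce to a guard on the scheduler's plugin registry: with idempotent=True the first registration wins, otherwise the latest one does. A minimal, self-contained sketch of that behavior (the function name and plain-dict registry are illustrative, not the scheduler's actual internals):

    def record_plugin(registry: dict, name: str, plugin: object, idempotent: bool) -> bool:
        """Return True if the plugin was stored, False if an existing one was kept."""
        if idempotent and name in registry:
            return False  # idempotent=True: keep the first registration
        registry[name] = plugin  # idempotent=False: the latest registration wins
        return True

    registry: dict = {}
    assert record_plugin(registry, "p", "first", idempotent=True)
    assert not record_plugin(registry, "p", "second", idempotent=True)
    assert registry["p"] == "first"
    assert record_plugin(registry, "p", "second", idempotent=False)
    assert registry["p"] == "second"
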
@gen_cluster(client=True) -async def test_register_scheduler_plugin(c, s, a, b): +async def test_register_plugin(c, s, a, b): class Dummy1(SchedulerPlugin): name = "Dummy1" @@ -366,11 +366,11 @@ def start(self, scheduler): scheduler.foo = "bar" assert not hasattr(s, "foo") - await c.register_scheduler_plugin(Dummy1()) + await c.register_plugin(Dummy1()) assert s.foo == "bar" with pytest.warns(UserWarning) as w: - await c.register_scheduler_plugin(Dummy1()) + await c.register_plugin(Dummy1()) assert "Scheduler already contains" in w[0].message.args[0] class Dummy2(SchedulerPlugin): @@ -381,20 +381,36 @@ def start(self, scheduler): n_plugins = len(s.plugins) with pytest.raises(RuntimeError, match="raising in start method"): - await c.register_scheduler_plugin(Dummy2()) + await c.register_plugin(Dummy2()) # total number of plugins should be unchanged assert n_plugins == len(s.plugins) +@gen_cluster(client=True) +async def test_register_scheduler_plugin_deprecated(c, s, a, b): + class Dummy(SchedulerPlugin): + name = "Dummy" + + def start(self, scheduler): + scheduler.foo = "bar" + + assert not hasattr(s, "foo") + with pytest.warns( + DeprecationWarning, match="register_scheduler_plugin.*deprecated" + ): + await c.register_scheduler_plugin(Dummy()) + assert s.foo == "bar" + + @gen_cluster(client=True, config={"distributed.scheduler.pickle": False}) -async def test_register_scheduler_plugin_pickle_disabled(c, s, a, b): +async def test_register_plugin_pickle_disabled(c, s, a, b): class Dummy1(SchedulerPlugin): def start(self, scheduler): scheduler.foo = "bar" n_plugins = len(s.plugins) with pytest.raises(ValueError) as excinfo: - await c.register_scheduler_plugin(Dummy1()) + await c.register_plugin(Dummy1()) msg = str(excinfo.value) assert "disallowed from deserializing" in msg @@ -426,7 +442,7 @@ class Plugin(SchedulerPlugin): name = "plugin" assert "plugin" not in s.plugins - await c.register_scheduler_plugin(Plugin()) + await c.register_plugin(Plugin()) assert "plugin" in s.plugins await c.unregister_scheduler_plugin("plugin") @@ -446,7 +462,7 @@ async def start(self, scheduler: Scheduler) -> None: def log_event(self, name, msg): self.scheduler._recorded_events.append((name, msg)) - await c.register_scheduler_plugin(EventPlugin()) + await c.register_plugin(EventPlugin()) def f(): get_worker().log_event("foo", 123) @@ -591,3 +607,37 @@ def update_graph( # type: ignore with dask.annotate(global_annot=24): await c.compute(f4) assert plugin.success + + +@gen_cluster(client=True, nthreads=[("", 1)]) +async def test_scheduler_plugin_in_register_worker_plugin_overrides(c, s, a): + class DuckPlugin(SchedulerPlugin): + def start(self, scheduler): + scheduler.foo = 123 + + def stop(self, scheduler): + pass + + n_existing_plugins = len(s.plugins) + assert not hasattr(s, "foo") + with pytest.warns(UserWarning, match="`SchedulerPlugin` as a worker plugin"): + await c.register_worker_plugin(DuckPlugin(), nanny=False) + assert len(s.plugins) == n_existing_plugins + 1 + assert s.foo == 123 + + +@gen_cluster(client=True, nthreads=[("", 1)], Worker=Nanny) +async def test_scheduler_plugin_in_register_worker_plugin_overrides_nanny(c, s, a): + class DuckPlugin(SchedulerPlugin): + def start(self, scheduler): + scheduler.foo = 123 + + def stop(self, scheduler): + pass + + n_existing_plugins = len(s.plugins) + assert not hasattr(s, "foo") + with pytest.warns(UserWarning, match="`SchedulerPlugin` as a nanny plugin"): + await c.register_worker_plugin(DuckPlugin(), nanny=True) + assert len(s.plugins) == 
n_existing_plugins + 1 + assert s.foo == 123 diff --git a/distributed/diagnostics/tests/test_worker_plugin.py b/distributed/diagnostics/tests/test_worker_plugin.py index 4d55041aca..64fffd4249 100644 --- a/distributed/diagnostics/tests/test_worker_plugin.py +++ b/distributed/diagnostics/tests/test_worker_plugin.py @@ -42,7 +42,7 @@ def transition(self, key, start, finish, **kwargs): @gen_cluster(client=True, nthreads=[]) async def test_create_with_client(c, s): - await c.register_worker_plugin(MyPlugin(123)) + await c.register_plugin(MyPlugin(123)) async with Worker(s.address) as worker: assert worker._my_plugin_status == "setup" @@ -55,8 +55,8 @@ async def test_create_with_client(c, s): async def test_remove_with_client(c, s): existing_plugins = s.worker_plugins.copy() n_existing_plugins = len(existing_plugins) - await c.register_worker_plugin(MyPlugin(123), name="foo") - await c.register_worker_plugin(MyPlugin(546), name="bar") + await c.register_plugin(MyPlugin(123), name="foo") + await c.register_plugin(MyPlugin(546), name="bar") async with Worker(s.address) as worker: # remove the 'foo' plugin @@ -80,7 +80,7 @@ async def test_remove_with_client(c, s): @gen_cluster(client=True, nthreads=[]) async def test_remove_with_client_raises(c, s): - await c.register_worker_plugin(MyPlugin(123), name="foo") + await c.register_plugin(MyPlugin(123), name="foo") async with Worker(s.address): with pytest.raises(ValueError, match="bar"): @@ -109,7 +109,7 @@ async def test_normal_task_transitions_called(c, s, w): plugin = MyPlugin(1, expected_notifications=expected_notifications) - await c.register_worker_plugin(plugin) + await c.register_plugin(plugin) await c.submit(lambda x: x, 1, key="task") await async_poll_for(lambda: not w.state.tasks, timeout=10) @@ -133,7 +133,7 @@ def failing(x): plugin = MyPlugin(1, expected_notifications=expected_notifications) - await c.register_worker_plugin(plugin) + await c.register_plugin(plugin) with pytest.raises(CustomError): await c.submit(failing, 1, key="task") @@ -155,7 +155,7 @@ async def test_superseding_task_transitions_called(c, s, w): plugin = MyPlugin(1, expected_notifications=expected_notifications) - await c.register_worker_plugin(plugin) + await c.register_plugin(plugin) await c.submit(lambda x: x, 1, key="task", resources={"X": 1}) await async_poll_for(lambda: not w.state.tasks, timeout=10) @@ -181,7 +181,7 @@ async def test_dependent_tasks(c, s, w): plugin = MyPlugin(1, expected_notifications=expected_notifications) - await c.register_worker_plugin(plugin) + await c.register_plugin(plugin) await c.get(dsk, "task", sync=False) await async_poll_for(lambda: not w.state.tasks, timeout=10) @@ -191,7 +191,7 @@ async def test_empty_plugin(c, s, w): class EmptyPlugin(WorkerPlugin): pass - await c.register_worker_plugin(EmptyPlugin()) + await c.register_plugin(EmptyPlugin()) @gen_cluster(nthreads=[("127.0.0.1", 1)], client=True) @@ -200,7 +200,7 @@ class MyCustomPlugin(WorkerPlugin): pass n_existing_plugins = len(w.plugins) - await c.register_worker_plugin(MyCustomPlugin()) + await c.register_plugin(MyCustomPlugin()) assert len(w.plugins) == n_existing_plugins + 1 assert any(name.startswith("MyCustomPlugin-") for name in w.plugins) @@ -215,7 +215,7 @@ class Dummy(WorkerPlugin): pass with warnings.catch_warnings(record=True) as record: - await c.register_worker_plugin(Dummy()) + await c.register_plugin(Dummy()) assert await c.submit(inc, 1, key="x") == 2 while "x" in a.state.tasks: await asyncio.sleep(0.01) @@ -238,7 +238,7 @@ def transition(self, *args, 
**kwargs): def teardown(self, worker): del self.worker.foo - await c.register_worker_plugin(MyCustomPlugin()) + await c.register_plugin(MyCustomPlugin()) assert w.foo == 0 @@ -261,7 +261,7 @@ def transition(self, *args, **kwargs): def teardown(self, worker): del self.worker.bar - await c.register_worker_plugin(MyCustomPlugin()) + await c.register_plugin(MyCustomPlugin()) assert not hasattr(w, "foo") assert w.bar == 0 @@ -271,15 +271,91 @@ def teardown(self, worker): @gen_cluster(client=True, nthreads=[("", 1)]) -async def test_duck_typed_worker_plugin_is_deprecated(c, s, a): - class DuckPlugin: +async def test_register_worker_plugin_is_deprecated(c, s, a): + class DuckPlugin(WorkerPlugin): + def setup(self, worker): + worker.foo = 123 + + def teardown(self, worker): + pass + + n_existing_plugins = len(a.plugins) + assert not hasattr(a, "foo") + with pytest.warns(DeprecationWarning, match="register_worker_plugin.*deprecated"): + await c.register_worker_plugin(DuckPlugin()) + assert len(a.plugins) == n_existing_plugins + 1 + assert a.foo == 123 + + +@gen_cluster(client=True, nthreads=[("", 1)]) +async def test_register_worker_plugin_typing_over_nanny_keyword(c, s, a): + class DuckPlugin(WorkerPlugin): def setup(self, worker): + worker.foo = 123 + + def teardown(self, worker): pass + n_existing_plugins = len(a.plugins) + assert not hasattr(a, "foo") + with pytest.warns(UserWarning, match="`WorkerPlugin` as a nanny plugin"): + await c.register_worker_plugin(DuckPlugin(), nanny=True) + assert len(a.plugins) == n_existing_plugins + 1 + assert a.foo == 123 + + +@gen_cluster(client=True, nthreads=[("", 1)]) +async def test_duck_typed_register_worker_plugin_is_deprecated(c, s, a): + class DuckPlugin: + def setup(self, worker): + worker.foo = 123 + def teardown(self, worker): pass n_existing_plugins = len(a.plugins) + assert not hasattr(a, "foo") with pytest.warns(DeprecationWarning, match="duck-typed.*WorkerPlugin"): await c.register_worker_plugin(DuckPlugin()) assert len(a.plugins) == n_existing_plugins + 1 + assert a.foo == 123 + + +@gen_cluster(client=True, nthreads=[("", 1)]) +async def test_register_idempotent_plugins(c, s, a): + class IdempotentPlugin(WorkerPlugin): + def __init__(self, instance=None): + self.name = "idempotentplugin" + self.instance = instance + + def setup(self, worker): + if self.instance != "first": + raise RuntimeError( + "Only the first plugin should be started when idempotent is set" + ) + + first = IdempotentPlugin(instance="first") + await c.register_plugin(first, idempotent=True) + assert "idempotentplugin" in a.plugins + + second = IdempotentPlugin(instance="second") + await c.register_plugin(second, idempotent=True) + assert "idempotentplugin" in a.plugins + assert a.plugins["idempotentplugin"].instance == "first" + + +@gen_cluster(client=True, nthreads=[("", 1)]) +async def test_register_non_idempotent_plugins(c, s, a): + class NonIdempotentPlugin(WorkerPlugin): + def __init__(self, instance=None): + self.name = "nonidempotentplugin" + self.instance = instance + + first = NonIdempotentPlugin(instance="first") + await c.register_plugin(first, idempotent=False) + assert "nonidempotentplugin" in a.plugins + + second = NonIdempotentPlugin(instance="second") + await c.register_plugin(second, idempotent=False) + assert "nonidempotentplugin" in a.plugins + assert a.plugins["nonidempotentplugin"].instance == "second" diff --git a/distributed/nanny.py b/distributed/nanny.py index 386e893988..7747adc863 100644 --- a/distributed/nanny.py +++ b/distributed/nanny.py @@ -461,7 
+461,7 @@ async def plugin_add( if not isinstance(plugin, NannyPlugin): warnings.warn( "Registering duck-typed plugins has been deprecated. " - "Please make sure your plugin subclasses `NannyPlugin`.", + "Please make sure your plugin inherits from `NannyPlugin`.", DeprecationWarning, stacklevel=2, ) @@ -469,7 +469,6 @@ async def plugin_add( if name is None: name = _get_plugin_name(plugin) - assert name self.plugins[name] = plugin @@ -488,7 +487,7 @@ async def plugin_add( return {"status": "OK"} @log_errors - async def plugin_remove(self, name=None): + async def plugin_remove(self, name: str) -> ErrorMessage | OKMessage: logger.info(f"Removing Nanny plugin {name}") try: plugin = self.plugins.pop(name) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index cc798529e5..e29983abbf 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -81,7 +81,15 @@ ) from distributed.comm.addressing import addresses_from_user_args from distributed.compatibility import PeriodicCallback -from distributed.core import Status, clean_exception, error_message, rpc, send_recv +from distributed.core import ( + ErrorMessage, + OKMessage, + Status, + clean_exception, + error_message, + rpc, + send_recv, +) from distributed.diagnostics.memory_sampler import MemorySamplerExtension from distributed.diagnostics.plugin import SchedulerPlugin, _get_plugin_name from distributed.event import EventExtension @@ -7483,9 +7491,14 @@ def stop_task_metadata(self, name: str | None = None) -> dict: self.remove_plugin(name=plugin.name) return {"metadata": plugin.metadata, "state": plugin.state} - async def register_worker_plugin(self, comm, plugin, name=None): + async def register_worker_plugin( + self, plugin: bytes, name: str, idempotent: bool = False + ) -> dict[str, OKMessage]: """Registers a worker plugin on all running and future workers""" logger.info("Registering Worker plugin %s", name) + if name in self.worker_plugins and idempotent: + return {} + self.worker_plugins[name] = plugin responses = await self.broadcast( @@ -7493,7 +7506,9 @@ async def register_worker_plugin(self, comm, plugin, name=None): ) return responses - async def unregister_worker_plugin(self, comm, name): + async def unregister_worker_plugin( + self, name: str + ) -> dict[str, ErrorMessage | OKMessage]: """Unregisters a worker plugin""" try: self.worker_plugins.pop(name) @@ -7503,9 +7518,14 @@ async def unregister_worker_plugin(self, comm, name): responses = await self.broadcast(msg=dict(op="plugin-remove", name=name)) return responses - async def register_nanny_plugin(self, comm, plugin, name): - """Registers a setup function, and call it on every worker""" + async def register_nanny_plugin( + self, plugin: bytes, name: str, idempotent: bool = False + ) -> dict[str, OKMessage]: + """Registers a nanny plugin on all running and future nannies""" logger.info("Registering Nanny plugin %s", name) + if name in self.nanny_plugins and idempotent: + return {} + self.nanny_plugins[name] = plugin async with self._starting_nannies_cond: if self._starting_nannies: @@ -7519,7 +7539,9 @@ async def register_nanny_plugin(self, comm, plugin, name): ) return responses - async def unregister_nanny_plugin(self, comm, name): + async def unregister_nanny_plugin( + self, name: str + ) -> dict[str, ErrorMessage | OKMessage]: """Unregisters a worker plugin""" try: self.nanny_plugins.pop(name) diff --git a/distributed/shuffle/_scheduler_plugin.py b/distributed/shuffle/_scheduler_plugin.py index ce69560cee..a51af3de29 100644 --- 
a/distributed/shuffle/_scheduler_plugin.py +++ b/distributed/shuffle/_scheduler_plugin.py @@ -66,7 +66,7 @@ def __init__(self, scheduler: Scheduler): async def start(self, scheduler: Scheduler) -> None: worker_plugin = ShuffleWorkerPlugin() await self.scheduler.register_worker_plugin( - None, dumps(worker_plugin), name="shuffle" + dumps(worker_plugin), name="shuffle" ) def shuffle_ids(self) -> set[ShuffleId]: diff --git a/distributed/shuffle/tests/test_shuffle.py b/distributed/shuffle/tests/test_shuffle.py index 453fc279f4..68f97a2558 100644 --- a/distributed/shuffle/tests/test_shuffle.py +++ b/distributed/shuffle/tests/test_shuffle.py @@ -421,7 +421,7 @@ async def _get_or_create_shuffle(self, *args, **kwargs): config={"distributed.scheduler.allowed-failures": 0}, ) async def test_get_or_create_from_dangling_transfer(c, s, a, b): - await c.register_worker_plugin(BlockedGetOrCreateWorkerPlugin(), name="shuffle") + await c.register_plugin(BlockedGetOrCreateWorkerPlugin(), name="shuffle") df = dask.datasets.timeseries( start="2000-01-01", end="2000-03-01", @@ -1786,9 +1786,7 @@ async def shuffle_receive(self, *args: Any, **kwargs: Any) -> None: @pytest.mark.parametrize("wait_until_forgotten", [True, False]) @gen_cluster(client=True, nthreads=[("", 1)] * 2) async def test_deduplicate_stale_transfer(c, s, a, b, wait_until_forgotten): - await c.register_worker_plugin( - BlockedShuffleReceiveShuffleWorkerPlugin(), name="shuffle" - ) + await c.register_plugin(BlockedShuffleReceiveShuffleWorkerPlugin(), name="shuffle") df = dask.datasets.timeseries( start="2000-01-01", end="2000-01-10", @@ -1839,7 +1837,7 @@ async def _barrier(self, *args: Any, **kwargs: Any) -> int: @pytest.mark.parametrize("wait_until_forgotten", [True, False]) @gen_cluster(client=True, nthreads=[("", 1)] * 2) async def test_handle_stale_barrier(c, s, a, b, wait_until_forgotten): - await c.register_worker_plugin(BlockedBarrierShuffleWorkerPlugin(), name="shuffle") + await c.register_plugin(BlockedBarrierShuffleWorkerPlugin(), name="shuffle") df = dask.datasets.timeseries( start="2000-01-01", end="2000-01-10", @@ -1891,7 +1889,7 @@ async def test_shuffle_run_consistency(c, s, a): The P2P implementation relies on the correctness of this behavior, but it is an implementation detail that users should not rely upon. 
""" - await c.register_worker_plugin(BlockedBarrierShuffleWorkerPlugin(), name="shuffle") + await c.register_plugin(BlockedBarrierShuffleWorkerPlugin(), name="shuffle") worker_plugin = a.plugins["shuffle"] scheduler_ext = s.plugins["shuffle"] @@ -1998,9 +1996,7 @@ def shuffle_fail(self, *args: Any, **kwargs: Any) -> None: @gen_cluster(client=True, nthreads=[("", 1)] * 2) async def test_replace_stale_shuffle(c, s, a, b): - await c.register_worker_plugin( - BlockedShuffleAccessAndFailWorkerPlugin(), name="shuffle" - ) + await c.register_plugin(BlockedShuffleAccessAndFailWorkerPlugin(), name="shuffle") ext_A = a.plugins["shuffle"] ext_B = b.plugins["shuffle"] diff --git a/distributed/shuffle/tests/utils.py b/distributed/shuffle/tests/utils.py index ef9d5f6c2d..2a3a7844c9 100644 --- a/distributed/shuffle/tests/utils.py +++ b/distributed/shuffle/tests/utils.py @@ -96,4 +96,4 @@ async def invoke_annotation_chaos(rate: float, client: Client) -> None: if not rate: return plugin = ShuffleAnnotationChaosPlugin(rate) - await client.register_scheduler_plugin(plugin) + await client.register_plugin(plugin) diff --git a/distributed/tests/test_chaos.py b/distributed/tests/test_chaos.py index 2c341744cd..2396f165fb 100644 --- a/distributed/tests/test_chaos.py +++ b/distributed/tests/test_chaos.py @@ -21,7 +21,7 @@ async def test_KillWorker(c, s, w, mode): plugin = KillWorker(delay="1ms", mode=mode) - await c.register_worker_plugin(plugin) + await c.register_plugin(plugin) while s.workers: await asyncio.sleep(0.001) diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 99ae9efbc8..8822d39c42 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -6907,7 +6907,7 @@ def test_futures_in_subgraphs(loop_in_thread): @gen_cluster(client=True) async def test_get_task_metadata(c, s, a, b): # Populate task metadata - await c.register_worker_plugin(TaskStateMetadataPlugin()) + await c.register_plugin(TaskStateMetadataPlugin()) async with get_task_metadata() as tasks: f = c.submit(slowinc, 1) @@ -6927,7 +6927,7 @@ async def test_get_task_metadata(c, s, a, b): @gen_cluster(client=True) async def test_get_task_metadata_multiple(c, s, a, b): # Populate task metadata - await c.register_worker_plugin(TaskStateMetadataPlugin()) + await c.register_plugin(TaskStateMetadataPlugin()) # Ensure that get_task_metadata only collects metadata for # tasks which are submitted and completed within its context @@ -6958,7 +6958,7 @@ def setup(self, worker=None): raise ValueError("Setup failed") with pytest.raises(ValueError, match="Setup failed"): - await c.register_worker_plugin(MyPlugin()) + await c.register_plugin(MyPlugin()) @gen_cluster(client=True, nthreads=[("", 1)]) @@ -7594,7 +7594,7 @@ async def test_upload_directory(c, s, a, b, tmp_path): f.write("from foo import x") plugin = UploadDirectory(tmp_path, restart=True, update_path=True) - await c.register_worker_plugin(plugin) + await c.register_plugin(plugin) [name] = a.plugins assert os.path.split(tmp_path)[-1] in name @@ -7616,6 +7616,22 @@ def f(): assert files_start == files_end # no change +@gen_cluster(client=True, nthreads=[("", 1)]) +async def test_duck_typed_register_plugin_raises(c, s, a): + class DuckPlugin: + def setup(self, worker): + pass + + def teardown(self, worker): + pass + + n_existing_plugins = len(a.plugins) + + with pytest.raises(TypeError, match="duck-typed.*inherit from.*Plugin"): + await c.register_plugin(DuckPlugin()) + assert len(a.plugins) == n_existing_plugins + + 
@gen_cluster(client=True) async def test_exception_text(c, s, a, b): def bad(x): diff --git a/distributed/tests/test_nanny.py b/distributed/tests/test_nanny.py index b0c3fafc9c..f0e824d3ed 100644 --- a/distributed/tests/test_nanny.py +++ b/distributed/tests/test_nanny.py @@ -590,7 +590,7 @@ async def test_failure_during_worker_initialization(s): async def test_environ_plugin(c, s, a, b): from dask.distributed import Environ - await c.register_worker_plugin(Environ({"ABC": 123})) + await c.register_plugin(Environ({"ABC": 123})) async with Nanny(s.address, name="new") as n: results = await c.run(os.getenv, "ABC") @@ -822,7 +822,7 @@ class C: async def test_nanny_plugin_simple(c, s, a): """A plugin should be registered to already existing workers but also to new ones.""" plugin = DummyNannyPlugin("foo") - await c.register_worker_plugin(plugin) + await c.register_plugin(plugin) assert a._plugin_registered async with Nanny(s.address) as n: assert n._plugin_registered @@ -865,7 +865,7 @@ async def test_nanny_plugin_register_during_start_success(c, s, restart): try: await n.in_instantiate.wait() - register = asyncio.create_task(c.register_worker_plugin(plugin)) + register = asyncio.create_task(c.register_plugin(plugin)) with pytest.raises(asyncio.TimeoutError): await asyncio.wait_for(asyncio.shield(register), timeout=0.1) n.wait_instantiate.set() @@ -898,7 +898,7 @@ async def test_nanny_plugin_register_during_start_failure(c, s, restart): start = asyncio.create_task(n.start()) await n.in_instantiate.wait() - register = asyncio.create_task(c.register_worker_plugin(plugin)) + register = asyncio.create_task(c.register_plugin(plugin)) with pytest.raises(asyncio.TimeoutError): await asyncio.wait_for(asyncio.shield(register), timeout=0.1) n.wait_instantiate.set() @@ -949,7 +949,7 @@ async def test_nanny_plugin_register_nanny_killed(c, s, restart): try: plugin = DummyNannyPlugin("foo", restart=restart) await asyncio.to_thread(in_instantiate.wait) - register = asyncio.create_task(c.register_worker_plugin(plugin)) + register = asyncio.create_task(c.register_plugin(plugin)) finally: proc.kill() assert await register == {} diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index cb16ab39f3..19b19ce937 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -392,9 +392,7 @@ async def test_graph_execution_width(c, s, *workers): passthrough1 = [delayed(slowidentity)(r, delay=0) for r in roots] passthrough2 = [delayed(slowidentity)(r, delay=0) for r in passthrough1] done = [delayed(lambda r: None)(r) for r in passthrough2] - await c.register_worker_plugin( - CountData(keys=[f.key for f in roots]), name="count-roots" - ) + await c.register_plugin(CountData(keys=[f.key for f in roots]), name="count-roots") fs = c.compute(done) await wait(fs) diff --git a/distributed/tests/test_stress.py b/distributed/tests/test_stress.py index 65b4fef899..608d3f0013 100644 --- a/distributed/tests/test_stress.py +++ b/distributed/tests/test_stress.py @@ -311,7 +311,7 @@ async def test_chaos_rechunk(c, s, *workers): plugin = KillWorker(delay="4 s", mode="sys.exit") - await c.register_worker_plugin(plugin, name="kill") + await c.register_plugin(plugin, name="kill") da = pytest.importorskip("dask.array") diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 6ce513f1e2..5daa7cc69a 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -1699,7 +1699,7 @@ async def test_pip_install(c, s, a): with 
mock.patch( "distributed.diagnostics.plugin.subprocess.Popen", return_value=mocked ) as Popen: - await c.register_worker_plugin( + await c.register_plugin( PipInstall(packages=["requests"], pip_options=["--upgrade"]) ) assert Popen.call_count == 1 @@ -1723,7 +1723,7 @@ async def test_conda_install(c, s, a): module_mock.run_command = run_command_mock module_mock.Commands.INSTALL = "INSTALL" with mock.patch.dict("sys.modules", {"conda.cli.python_api": module_mock}): - await c.register_worker_plugin( + await c.register_plugin( CondaInstall(packages=["requests"], conda_options=["--update-deps"]) ) assert run_command_mock.call_count == 1 @@ -1756,7 +1756,7 @@ async def test_pip_install_fails(c, s, a, b): "distributed.diagnostics.plugin.subprocess.Popen", return_value=mocked ) as Popen: with pytest.raises(RuntimeError): - await c.register_worker_plugin(PipInstall(packages=["not-a-package"])) + await c.register_plugin(PipInstall(packages=["not-a-package"])) assert Popen.call_count == 1 logs = logger.getvalue() @@ -1771,7 +1771,7 @@ async def test_conda_install_fails_when_conda_not_found(c, s, a, b): ) as logger: with mock.patch.dict("sys.modules", {"conda": None}): with pytest.raises(RuntimeError): - await c.register_worker_plugin(CondaInstall(packages=["not-a-package"])) + await c.register_plugin(CondaInstall(packages=["not-a-package"])) logs = logger.getvalue() assert "install failed" in logs assert "conda could not be found" in logs @@ -1789,7 +1789,7 @@ async def test_conda_install_fails_when_conda_raises(c, s, a, b): module_mock.Commands.INSTALL = "INSTALL" with mock.patch.dict("sys.modules", {"conda.cli.python_api": module_mock}): with pytest.raises(RuntimeError): - await c.register_worker_plugin(CondaInstall(packages=["not-a-package"])) + await c.register_plugin(CondaInstall(packages=["not-a-package"])) assert run_command_mock.call_count == 1 logs = logger.getvalue() assert "install failed" in logs @@ -1807,7 +1807,7 @@ async def test_conda_install_fails_on_returncode(c, s, a, b): module_mock.Commands.INSTALL = "INSTALL" with mock.patch.dict("sys.modules", {"conda.cli.python_api": module_mock}): with pytest.raises(RuntimeError): - await c.register_worker_plugin(CondaInstall(packages=["not-a-package"])) + await c.register_plugin(CondaInstall(packages=["not-a-package"])) assert run_command_mock.call_count == 1 logs = logger.getvalue() assert "install failed" in logs @@ -1830,7 +1830,7 @@ async def test_package_install_installs_once_with_multiple_workers(c, s, a, b): ) as logger: install_mock = mock.Mock(name="install") with mock.patch.object(StubInstall, "install", install_mock): - await c.register_worker_plugin( + await c.register_plugin( StubInstall( packages=["requests"], ) @@ -1844,7 +1844,7 @@ async def test_package_install_installs_once_with_multiple_workers(c, s, a, b): @gen_cluster(client=True, nthreads=[("", 1)], Worker=Nanny) async def test_package_install_restarts_on_nanny(c, s, a): (addr,) = s.workers - await c.register_worker_plugin( + await c.register_plugin( StubInstall( packages=["requests"], restart=True, @@ -1869,7 +1869,7 @@ def install(self) -> None: async def test_package_install_failing_does_not_restart_on_nanny(c, s, a): (addr,) = s.workers with pytest.raises(RuntimeError): - await c.register_worker_plugin( + await c.register_plugin( FailingInstall( packages=["requests"], restart=True, @@ -2037,7 +2037,7 @@ async def test_bad_local_directory(s): @gen_cluster(client=True, nthreads=[]) async def test_taskstate_metadata(c, s): async with Worker(s.address) as a: - await 
c.register_worker_plugin(TaskStateMetadataPlugin())
+    await c.register_plugin(TaskStateMetadataPlugin())
 
     f = c.submit(inc, 1)
     await f
@@ -3561,7 +3561,7 @@ def setup(self, worker):
         def teardown(self, worker):
             pass
 
-    await c.register_worker_plugin(InitWorkerNewThread())
+    await c.register_plugin(InitWorkerNewThread())
     async with Worker(s.address) as worker:
         assert await c.submit(inc, 1) == 2
         assert worker.plugins[InitWorkerNewThread.name].setup_status is Status.running
@@ -3671,7 +3671,7 @@ def print_stderr(*args, **kwargs):
     assert "" == err
 
     # After installing, output should be forwarded
-    await c.register_worker_plugin(plugin, "forward")
+    await c.register_plugin(plugin, "forward")
     await asyncio.sleep(0.1)  # Let setup messages come in
     capsys.readouterr()
 
@@ -3713,7 +3713,7 @@ def print_stderr(*args, **kwargs):
 
     # Registering the plugin is idempotent
     other_plugin = ForwardOutput()
-    await c.register_worker_plugin(other_plugin, "forward")
+    await c.register_plugin(other_plugin, "forward")
     await asyncio.sleep(0.1)  # Let teardown/setup messages come in
 
     out, err = capsys.readouterr()
diff --git a/distributed/worker.py b/distributed/worker.py
index 8f833bff1b..e08014cab9 100644
--- a/distributed/worker.py
+++ b/distributed/worker.py
@@ -1863,7 +1863,6 @@ async def plugin_add(
 
         if name is None:
             name = _get_plugin_name(plugin)
-
         assert name
 
         if name in self.plugins:
@@ -1885,7 +1884,7 @@ async def plugin_add(
         return {"status": "OK"}
 
     @log_errors
-    async def plugin_remove(self, name: str) -> dict[str, Any]:
+    async def plugin_remove(self, name: str) -> ErrorMessage | OKMessage:
         logger.info(f"Removing Worker plugin {name}")
         try:
             plugin = self.plugins.pop(name)
@@ -1894,8 +1893,7 @@ async def plugin_remove(self, name: str) -> dict[str, Any]:
             if isawaitable(result):
                 result = await result
         except Exception as e:
-            msg = error_message(e)
-            return cast("dict[str, Any]", msg)
+            return error_message(e)
 
         return {"status": "OK"}

From 38c6721595bddae3e20ba179a5bd6230d756ab8d Mon Sep 17 00:00:00 2001
From: Florian Jetter
Date: Fri, 8 Sep 2023 15:15:06 +0200
Subject: [PATCH 3/5] Fix post-stringification info pages (#8161)

---
 .../dashboard/tests/test_scheduler_bokeh.py | 70 +++++++++++++++++++
 distributed/http/scheduler/info.py          | 41 +++++++++--
 distributed/http/templates/task.html        | 17 ++++-----
 distributed/http/templates/worker.html      |  4 +-
 4 files changed, 117 insertions(+), 15 deletions(-)

diff --git a/distributed/dashboard/tests/test_scheduler_bokeh.py b/distributed/dashboard/tests/test_scheduler_bokeh.py
index db857e8eb3..8c9953c6b6 100644
--- a/distributed/dashboard/tests/test_scheduler_bokeh.py
+++ b/distributed/dashboard/tests/test_scheduler_bokeh.py
@@ -9,6 +9,8 @@
 import pytest
 
 pytest.importorskip("bokeh")
+from urllib.parse import quote_plus
+
 from bokeh.server.server import BokehTornado
 from tlz import first
 from tornado.httpclient import AsyncHTTPClient, HTTPRequest
@@ -1160,6 +1162,74 @@ async def test_memory_by_key(c, s, a, b):
     assert mbk.source.data["nbytes"] == [x.nbytes, sys.getsizeof(1)]
 
 
+@gen_cluster(client=True, scheduler_kwargs={"dashboard": True})
+async def test_worker_info(c, s, a, b):
+    port = s.http_server.port
+    ev1 = Event()
+    ev2 = Event()
+
+    def block_on_event(enter, wait):
+        enter.set()
+        wait.wait()
+
+    f = c.submit(block_on_event, ev1, ev2, workers=[a.address])
+    try:
+        host = f"http://127.0.0.1:{port}"
+        http_client = AsyncHTTPClient()
+        await ev1.wait()
+        response = await http_client.fetch(
+            f"{host}/info/task/foo.html", raise_error=False
+        )
+        assert response.code == 404
+        response = await http_client.fetch(
+            f"{host}/info/task/{quote_plus(str(f.key))}.html"
+        )
+        assert response.code == 200
+
+        response = await http_client.fetch(
+            f"{host}/info/call-stack/foo.html", raise_error=False
+        )
+        assert response.code == 404
+        response = await http_client.fetch(
+            f"{host}/info/call-stack/{quote_plus(str(f.key))}.html"
+        )
+        assert response.code == 200
+
+        response = await http_client.fetch(f"{host}/info/main/workers.html")
+        assert response.code == 200
+
+        response = await http_client.fetch(
+            f"{host}/info/worker/bar.html", raise_error=False
+        )
+        assert response.code == 404
+        response = await http_client.fetch(
+            f"{host}/info/call-stacks/bar.html", raise_error=False
+        )
+        assert response.code == 404
+        response = await http_client.fetch(
+            f"{host}/info/logs/bar.html", raise_error=False
+        )
+        assert response.code == 404
+
+        for w in [a, b]:
+            response = await http_client.fetch(
+                f"{host}/info/worker/{quote_plus(w.address)}.html"
+            )
+            assert response.code == 200
+            response = await http_client.fetch(
+                f"{host}/info/call-stacks/{quote_plus(w.address)}.html"
+            )
+            assert response.code == 200
+            response = await http_client.fetch(
+                f"{host}/info/logs/{quote_plus(w.address)}.html"
+            )
+            assert response.code == 200
+    finally:
+        await ev2.set()
+
+
 @gen_cluster(client=True, scheduler_kwargs={"dashboard": True})
 async def test_aggregate_action(c, s, a, b):
     mbk = AggregateAction(s)
diff --git a/distributed/http/scheduler/info.py b/distributed/http/scheduler/info.py
index bee6be4545..0290e5bddf 100644
--- a/distributed/http/scheduler/info.py
+++ b/distributed/http/scheduler/info.py
@@ -4,7 +4,9 @@
 import logging
 import os
 import os.path
+from collections.abc import Hashable
 from datetime import datetime
+from typing import TYPE_CHECKING
 
 from tlz import first, merge
 from tornado import escape
@@ -17,6 +19,9 @@
 from distributed.metrics import time
 from distributed.utils import log_errors
 
+if TYPE_CHECKING:
+    from distributed import Scheduler
+
 ns = {
     func.__name__: func
     for func in [format_bytes, format_time, datetime.fromtimestamp, time]
@@ -85,18 +90,28 @@ def get(self):
         )
 
 
+def _get_actual_scheduler_key(key: str, scheduler: Scheduler) -> Hashable:
+    for k in scheduler.tasks:
+        if str(k) == key:
+            return k
+    raise KeyError(key)
+
+
 class Task(RequestHandler):
     @log_errors
     def get(self, task):
         task = escape.url_unescape(task)
-        if task not in self.server.tasks:
+
+        try:
+            requested_key = _get_actual_scheduler_key(task, self.server)
+        except KeyError:
             self.send_error(404)
             return
+
         self.render(
             "task.html",
             title="Task: " + task,
-            Task=task,
+            Task=requested_key,
             scheduler=self.server,
             **merge(
                 self.server.__dict__,
@@ -124,7 +139,13 @@ class WorkerLogs(RequestHandler):
     @log_errors
     async def get(self, worker):
         worker = escape.url_unescape(worker)
-        logs = await self.server.get_worker_logs(workers=[worker])
+        try:
+            logs = await self.server.get_worker_logs(workers=[worker])
+        except Exception:
+            if not any(worker == w.address for w in self.server.workers.values()):
+                self.send_error(404)
+                return
+            raise
         logs = logs[worker]
         self.render(
             "logs.html",
@@ -138,7 +159,11 @@ class WorkerCallStacks(RequestHandler):
     @log_errors
     async def get(self, worker):
         worker = escape.url_unescape(worker)
-        keys = {ts.key for ts in self.server.workers[worker].processing}
+        try:
+            keys = {ts.key for ts in self.server.workers[worker].processing}
+        except KeyError:
+            self.send_error(404)
+            return
         call_stack = await self.server.get_call_stack(keys=keys)
         self.render(
             "call-stack.html",
@@ -152,7 +177,13 @@ class TaskCallStack(RequestHandler):
     @log_errors
     async def get(self, key):
         key = escape.url_unescape(key)
-        call_stack = await self.server.get_call_stack(keys=[key])
+
+        try:
+            requested_key = _get_actual_scheduler_key(key, self.server)
+        except KeyError:
+            self.send_error(404)
+            return
+        call_stack = await self.server.get_call_stack(keys=[requested_key])
         if not call_stack:
             self.write(
                 "<p>Task not actively running. "
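The _get_actual_scheduler_key helper exists because scheduler task keys are not
necessarily plain strings; after dask's key stringification changes they are
commonly tuples such as ("inc-1a2b", 0), while an info-page URL can only carry
the quote_plus-encoded str() of a key. The handlers therefore map the string
from the URL back to the real, possibly-tuple key before touching scheduler
state. A minimal standalone sketch of that round trip (the task dict and key
values here are invented for illustration, not taken from the patch):

    from urllib.parse import quote_plus, unquote_plus

    tasks = {("inc-1a2b", 0): object(), "finalize-3c4d": object()}

    def get_actual_scheduler_key(key: str):
        # The URL only carries str(key); recover the real, possibly-tuple
        # key by comparing stringifications, as the patch's helper does.
        for k in tasks:
            if str(k) == key:
                return k
        raise KeyError(key)

    fragment = quote_plus(str(("inc-1a2b", 0)))  # '%28%27inc-1a2b%27%2C+0%29'
    assert get_actual_scheduler_key(unquote_plus(fragment)) == ("inc-1a2b", 0)

The lookup is linear in the number of tasks, which is an acceptable trade-off
for a page that renders one task per request and avoids maintaining a second
index keyed by str(key).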

diff --git a/distributed/http/templates/task.html b/distributed/http/templates/task.html
index 2581254804..2344c680a9 100644
--- a/distributed/http/templates/task.html
+++ b/distributed/http/templates/task.html
@@ -16,7 +16,7 @@ <h3>Task: {{ ts.key }}</h3>
     <tr>
         <td>Call stack</td>
-        <td><a href="../call-stack/{{ts.key}}.html">Call Stack</a></td>
+        <td><a href="../call-stack/{{ url_escape(str(ts.key), plus=True) }}.html">Call Stack</a></td>
     </tr>
     {% end %}
 
     {% if ts.type %}
@@ -35,7 +35,7 @@ <h3>Task: {{ ts.key }}</h3>
     {% for dts in ts.waiting_on %}
     <tr>
         <td>waiting on</td>
-        <td><a href="{{dts.key}}.html">{{dts.key}}</a></td>
+        <td><a href="{{ url_escape(str(dts.key), plus=True) }}.html">{{dts.key}}</a></td>
     </tr>
     {% end %}
     {% end %}
@@ -79,7 +79,7 @@ <h3>Dependencies</h3>
     {% for dts in ts.dependencies %}
     <tr>
-        <td><a href="{{dts.key}}.html">{{dts.key}}</a></td>
+        <td><a href="{{ url_escape(str(dts.key), plus=True) }}.html">{{dts.key}}</a></td>
         <td>{{ dts.state }}</td>
     </tr>
     {% end %}
@@ -95,7 +95,7 @@ <h3>Dependents</h3>
     {% for dts in ts.dependents %}
     <tr>
-        <td><a href="{{dts.key}}.html">{{dts.key}}</a></td>
+        <td><a href="{{ url_escape(str(dts.key), plus=True) }}.html">{{dts.key}}</a></td>
         <td>{{ dts.state }}</td>
     </tr>
     {% end %}
@@ -117,8 +117,9 @@ <h3>Workers with data</h3>
 [hunk body lost in extraction: only tag markup changed between the "Workers
 with data" table and the "Clients with future" heading]
     <h3>Clients with future</h3>
@@ -139,10 +140,10 @@ <h3>Transition Log</h3>
         <th>Recommended Action</th>
     </tr>
     </thead>
-    {% for key, start, finish, recommendations, stimulus_id, transition_time in scheduler.story(Task) %}
+    {% for key, start, finish, recommendations, stimulus_id, transition_time in scheduler.story(ts.key) %}
     <tr>
         <td>{{ fromtimestamp(transition_time) }}</td>
-        <td><a href="{{key}}.html">{{key}}</a></td>
+        <td><a href="{{ url_escape(str(key), plus=True) }}.html">{{key}}</a></td>
         <td>{{ start }}</td>
         <td>{{ finish }}</td>
         <td>{{ stimulus_id }}</td>
@@ -156,7 +157,7 @@ <h3>Transition Log</h3>
     {% for key2, rec in recommendations.items() %}
     <tr>
-        <td><a href="{{key2}}.html">{{key2}}</a></td>
+        <td><a href="{{ url_escape(str(key2), plus=True) }}.html">{{key2}}</a></td>
         <td>{{ rec }}</td>
     </tr>
     {% end %}
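The href attributes in the task.html hunks above and the worker.html hunks
below did not survive extraction; they are restored on the assumption that the
templates escape links with tornado's built-in url_escape(..., plus=True),
which produces exactly the quote_plus encoding that test_worker_info fetches.
A quick equivalence check (the worker address is made up):

    from urllib.parse import quote_plus

    from tornado.escape import url_escape

    address = "tcp://127.0.0.1:34567"
    assert url_escape(address, plus=True) == quote_plus(address)
    print(quote_plus(address))  # tcp%3A%2F%2F127.0.0.1%3A34567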

diff --git a/distributed/http/templates/worker.html b/distributed/http/templates/worker.html
index 9c7608cb8c..e56f8f9ce7 100644
--- a/distributed/http/templates/worker.html
+++ b/distributed/http/templates/worker.html
@@ -16,7 +16,7 @@ <h3>In Memory</h3>
     {% for ts in ws.has_what %}
     <tr>
-        <td><a href="../task/{{ts.key}}.html">{{ts.key}}</a></td>
+        <td><a href="../task/{{ url_escape(str(ts.key), plus=True) }}.html">{{ts.key}}</a></td>
         <td>{{format_bytes(ts.nbytes)}}</td>
     </tr>
     {% end %}
@@ -35,7 +35,7 @@ <h3>Processing</h3>
     {% for ts in sorted(ws.processing, key=lambda ts: ts.priority) %}
     <tr>
-        <td><a href="../task/{{ts.key}}.html">{{ts.key}}</a></td>
+        <td><a href="../task/{{ url_escape(str(ts.key), plus=True) }}.html">{{ts.key}}</a></td>
         <td>{{ts.priority }}</td>
     </tr>
     {% end %}

From e350c9949dca6e56187f2806e39991abe0c4dbb7 Mon Sep 17 00:00:00 2001
From: crusaderky
Date: Fri, 8 Sep 2023 16:09:53 +0100
Subject: [PATCH 4/5] Fix race condition between MemorySampler and scheduler
 shutdown (#8172)

---
 distributed/diagnostics/memory_sampler.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/distributed/diagnostics/memory_sampler.py b/distributed/diagnostics/memory_sampler.py
index 3443b1ebba..a805bdc9c1 100644
--- a/distributed/diagnostics/memory_sampler.py
+++ b/distributed/diagnostics/memory_sampler.py
@@ -217,6 +217,7 @@ def sample():
 
     def stop(self, key: str) -> list[tuple[float, int]]:
         """Stop sampling and return the samples"""
-        pc = self.scheduler.periodic_callbacks.pop("MemorySampler-" + key)
-        pc.stop()
+        pc = self.scheduler.periodic_callbacks.pop("MemorySampler-" + key, None)
+        if pc is not None:  # Race condition with scheduler shutdown
+            pc.stop()
         return self.samples.pop(key)

From 7744d68f042210833829a0cdef0521135761115d Mon Sep 17 00:00:00 2001
From: Matthew Rocklin
Date: Mon, 11 Sep 2023 11:15:41 -0500
Subject: [PATCH 5/5] Skip ``rechunker`` in code samples (#8178)

---
 distributed/distributed.yaml | 1 +
 1 file changed, 1 insertion(+)

diff --git a/distributed/distributed.yaml b/distributed/distributed.yaml
index d6285eb466..6ea2a603f5 100644
--- a/distributed/distributed.yaml
+++ b/distributed/distributed.yaml
@@ -282,6 +282,7 @@ distributed:
         - cudf
         - cuml
         - prefect
+        - rechunker
         - xarray
         - xgboost
     erred-tasks:
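For context on the MemorySampler change in patch 4: sample() installs a
periodic callback on the scheduler and stop() removes it again, so tearing
the cluster down while a sample was still open used to raise KeyError from
periodic_callbacks.pop(); with the pop(..., None) default, stop() now degrades
gracefully. A rough usage sketch of the API involved (the label and workload
are invented; the import path is the one the distributed docs use):

    import asyncio

    from distributed import Client
    from distributed.diagnostics import MemorySampler

    async def main() -> None:
        ms = MemorySampler()
        async with Client(processes=False, asynchronous=True) as client:
            # Entering the block registers the periodic callback on the
            # scheduler; leaving it calls stop(), which is what races with
            # shutdown when the cluster is torn down at the same time.
            async with ms.sample("run-1"):
                await client.submit(sum, range(1_000_000))
        # Samples outlive the cluster; inspect them with ms.to_pandas()

    asyncio.run(main())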