# Patch summary (explanatory preamble placed ahead of the first "diff --git" header).
#
# distributed/diagnostics/plugin.py
#   * start(): adds function-local imports of clean_exception (distributed.core)
#     and Serialized / deserialize (distributed.protocol.serialize).
#   * start(): the results of scheduler.register_nanny_plugin(...) and
#     scheduler.register_worker_plugin(...) are now bound to `responses`
#     instead of being discarded.
#   * start(): iterates responses.values(); for a response whose "status" is
#     "error", rebuilds the dict from only those entries whose value is a
#     Serialized instance (each deserialized from its header/frames), passes it
#     to clean_exception(**response), and raises the returned exception with the
#     returned traceback. The raise happens inside the loop, so only the first
#     error response encountered is surfaced.
#   * setup(): drops the register=True argument from the Semaphore(...) call.
#
# distributed/diagnostics/tests/test_install_plugin.py
#   * Adds test_package_install_on_worker and test_package_install_on_nanny, both
#     marked @pytest.mark.slow and run under gen_cluster(client=True,
#     nthreads=[("", 1)]); the second also passes Worker=Nanny. Each registers
#     InstallPlugin(lambda: None, restart_workers=False) through the client and
#     contains no explicit assertions.
#   * NOTE(review): `(addr,) = s.workers` binds `addr`, which is not used again in
#     either new test; the unpacking itself still requires exactly one worker.
#
# NOTE(review): the hunks below appear with their line breaks collapsed; the
# hunk headers (e.g. "@@ -442,6 +442,9 @@") describe multi-line hunks, so the
# original line structure presumably has to be restored before this can be applied.
diff --git a/distributed/diagnostics/plugin.py b/distributed/diagnostics/plugin.py index 7e2c8eea75..669f915292 100644 --- a/distributed/diagnostics/plugin.py +++ b/distributed/diagnostics/plugin.py @@ -442,6 +442,9 @@ def __init__( self.name = f"{self.__class__.__name__}-{uuid.uuid4()}" async def start(self, scheduler: Scheduler) -> None: + from distributed.core import clean_exception + from distributed.protocol.serialize import Serialized, deserialize + self._scheduler = scheduler if InstallPlugin._lock is None: @@ -452,7 +455,7 @@ async def start(self, scheduler: Scheduler) -> None: if self.restart_workers: nanny_plugin = _InstallNannyPlugin(self._install_fn, self.name) - await scheduler.register_nanny_plugin( + responses = await scheduler.register_nanny_plugin( comm=None, plugin=dumps(nanny_plugin), name=self.name, @@ -460,12 +463,21 @@ async def start(self, scheduler: Scheduler) -> None: ) else: worker_plugin = _InstallWorkerPlugin(self._install_fn, self.name) - await scheduler.register_worker_plugin( + responses = await scheduler.register_worker_plugin( comm=None, plugin=dumps(worker_plugin), name=self.name, idempotent=True, ) + for response in responses.values(): + if response["status"] == "error": + response = { # type: ignore[unreachable] + k: deserialize(v.header, v.frames) + for k, v in response.items() + if isinstance(v, Serialized) + } + _, exc, tb = clean_exception(**response) + raise exc.with_traceback(tb) async def close(self) -> None: assert InstallPlugin._lock is not None @@ -563,7 +575,6 @@ async def setup(self, worker): await Semaphore( max_leases=1, name=socket.gethostname(), - register=True, scheduler_rpc=worker.scheduler, loop=worker.loop, ) diff --git a/distributed/diagnostics/tests/test_install_plugin.py b/distributed/diagnostics/tests/test_install_plugin.py index 26e3a6f937..3ebdfefdef 100644 --- a/distributed/diagnostics/tests/test_install_plugin.py +++ b/distributed/diagnostics/tests/test_install_plugin.py @@ -140,6 +140,21 @@ async def 
test_conda_install_fails_on_returncode(c, s, a, b): assert "install failed" in logs +@pytest.mark.slow +@gen_cluster(client=True, nthreads=[("", 1)]) +async def test_package_install_on_worker(c, s, a): + (addr,) = s.workers + + await c.register_plugin(InstallPlugin(lambda: None, restart_workers=False)) + + +@pytest.mark.slow +@gen_cluster(client=True, nthreads=[("", 1)], Worker=Nanny) +async def test_package_install_on_nanny(c, s, a): + (addr,) = s.workers + await c.register_plugin(InstallPlugin(lambda: None, restart_workers=False)) + + @pytest.mark.slow @gen_cluster(client=True, nthreads=[("", 1)], Worker=Nanny) async def test_package_install_restarts_on_nanny(c, s, a):