Skip to content

Commit

Permalink
Fix PipInstall plugin on Worker (#8839)
Browse files Browse the repository at this point in the history
  • Loading branch information
hendrikmakait authored Aug 20, 2024
1 parent f5c30e8 commit 30e01fb
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 3 deletions.
17 changes: 14 additions & 3 deletions distributed/diagnostics/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,9 @@ def __init__(
self.name = f"{self.__class__.__name__}-{uuid.uuid4()}"

async def start(self, scheduler: Scheduler) -> None:
from distributed.core import clean_exception
from distributed.protocol.serialize import Serialized, deserialize

self._scheduler = scheduler

if InstallPlugin._lock is None:
Expand All @@ -452,20 +455,29 @@ async def start(self, scheduler: Scheduler) -> None:

if self.restart_workers:
nanny_plugin = _InstallNannyPlugin(self._install_fn, self.name)
await scheduler.register_nanny_plugin(
responses = await scheduler.register_nanny_plugin(
comm=None,
plugin=dumps(nanny_plugin),
name=self.name,
idempotent=True,
)
else:
worker_plugin = _InstallWorkerPlugin(self._install_fn, self.name)
await scheduler.register_worker_plugin(
responses = await scheduler.register_worker_plugin(
comm=None,
plugin=dumps(worker_plugin),
name=self.name,
idempotent=True,
)
for response in responses.values():
if response["status"] == "error":
response = { # type: ignore[unreachable]
k: deserialize(v.header, v.frames)
for k, v in response.items()
if isinstance(v, Serialized)
}
_, exc, tb = clean_exception(**response)
raise exc.with_traceback(tb)

async def close(self) -> None:
assert InstallPlugin._lock is not None
Expand Down Expand Up @@ -563,7 +575,6 @@ async def setup(self, worker):
await Semaphore(
max_leases=1,
name=socket.gethostname(),
register=True,
scheduler_rpc=worker.scheduler,
loop=worker.loop,
)
Expand Down
15 changes: 15 additions & 0 deletions distributed/diagnostics/tests/test_install_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,21 @@ async def test_conda_install_fails_on_returncode(c, s, a, b):
assert "install failed" in logs


@pytest.mark.slow
@gen_cluster(client=True, nthreads=[("", 1)])
async def test_package_install_on_worker(c, s, a):
(addr,) = s.workers

await c.register_plugin(InstallPlugin(lambda: None, restart_workers=False))


@pytest.mark.slow
@gen_cluster(client=True, nthreads=[("", 1)], Worker=Nanny)
async def test_package_install_on_nanny(c, s, a):
(addr,) = s.workers
await c.register_plugin(InstallPlugin(lambda: None, restart_workers=False))


@pytest.mark.slow
@gen_cluster(client=True, nthreads=[("", 1)], Worker=Nanny)
async def test_package_install_restarts_on_nanny(c, s, a):
Expand Down

0 comments on commit 30e01fb

Please sign in to comment.