Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into cudf-spilling-dashboard
Browse files Browse the repository at this point in the history
  • Loading branch information
charlesbluca committed Sep 11, 2023
2 parents 50e626a + 7744d68 commit 98e283e
Show file tree
Hide file tree
Showing 28 changed files with 671 additions and 186 deletions.
236 changes: 188 additions & 48 deletions distributed/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@
from concurrent.futures._base import DoneAndNotDoneFutures
from contextlib import asynccontextmanager, contextmanager, suppress
from contextvars import ContextVar
from functools import partial
from functools import partial, singledispatchmethod
from importlib.metadata import PackageNotFoundError, version
from numbers import Number
from queue import Queue as pyQueue
from typing import Any, ClassVar, Literal, NamedTuple, TypedDict
from typing import Any, Callable, ClassVar, Literal, NamedTuple, TypedDict, cast

from packaging.version import parse as parse_version
from tlz import first, groupby, merge, partition_all, valmap
Expand All @@ -46,7 +46,7 @@
)
from dask.widgets import get_template

from distributed.core import ErrorMessage
from distributed.core import ErrorMessage, OKMessage
from distributed.protocol.serialize import _is_dumpable
from distributed.utils import Deadline, wait_for

Expand Down Expand Up @@ -75,6 +75,7 @@
from distributed.diagnostics.plugin import (
ForwardLoggingPlugin,
NannyPlugin,
SchedulerPlugin,
SchedulerUploadFile,
UploadFile,
WorkerPlugin,
Expand Down Expand Up @@ -3781,10 +3782,11 @@ def upload_file(self, filename, load: bool = True):

async def _():
results = await asyncio.gather(
self.register_scheduler_plugin(
self.register_plugin(
SchedulerUploadFile(filename, load=load), name=name
),
self.register_worker_plugin(UploadFile(filename, load=load), name=name),
# FIXME: Make scheduler plugin responsible for (de)registering worker plugin
self.register_plugin(UploadFile(filename, load=load), name=name),
)
return results[1] # Results from workers upload

Expand Down Expand Up @@ -4814,15 +4816,90 @@ async def _get_task_stream(
else:
return msgs

async def _register_scheduler_plugin(self, plugin, name, idempotent=False):
def register_plugin(
    self,
    plugin: NannyPlugin | SchedulerPlugin | WorkerPlugin,
    name: str | None = None,
    idempotent: bool = False,
):
    """Register a plugin.

    See https://distributed.readthedocs.io/en/latest/plugins.html

    Parameters
    ----------
    plugin :
        A nanny, scheduler, or worker plugin to register.
    name :
        Name for the plugin; if None, a name is taken from the
        plugin instance or automatically generated if not present.
    idempotent :
        Do not re-register if a plugin of the given name already exists.

    Returns
    -------
    The result of ``self._register_plugin(plugin, name, idempotent)``.

    Raises
    ------
    ValueError
        If the resolved plugin name is empty.
    """
    if name is None:
        name = _get_plugin_name(plugin)
    # Explicit check rather than ``assert``: assertions are stripped under
    # ``python -O``, which would let an empty name through unchecked.
    if not name:
        raise ValueError(f"Plugin name must be a non-empty string; got {name!r}")

    return self._register_plugin(plugin, name, idempotent)

@singledispatchmethod
def _register_plugin(
    self,
    plugin: NannyPlugin | SchedulerPlugin | WorkerPlugin,
    name: str,
    idempotent: bool,
):
    """Fallback for plugins whose type has no registered overload.

    The overloads registered on this dispatcher handle the supported plugin
    base classes; anything else ends up here and is rejected.

    Raises
    ------
    TypeError
        Always.
    """
    message = (
        "Registering duck-typed plugins is not allowed. Please inherit from "
        "NannyPlugin, WorkerPlugin, or SchedulerPlugin to create a plugin."
    )
    raise TypeError(message)

@_register_plugin.register
def _(self, plugin: SchedulerPlugin, name: str, idempotent: bool):
    """``SchedulerPlugin`` overload of ``_register_plugin``.

    Runs ``self._register_scheduler_plugin`` through ``self.sync`` and
    returns its result.
    """
    call_kwargs = {"plugin": plugin, "name": name, "idempotent": idempotent}
    return self.sync(self._register_scheduler_plugin, **call_kwargs)

@_register_plugin.register
def _(
    self, plugin: NannyPlugin, name: str, idempotent: bool
) -> dict[str, OKMessage]:
    """``NannyPlugin`` overload of ``_register_plugin``.

    Runs ``self._register_nanny_plugin`` through ``self.sync`` and returns
    its result.
    """
    call_kwargs = {"plugin": plugin, "name": name, "idempotent": idempotent}
    return self.sync(self._register_nanny_plugin, **call_kwargs)

@_register_plugin.register
def _(self, plugin: WorkerPlugin, name: str, idempotent: bool):
    """``WorkerPlugin`` overload of ``_register_plugin``.

    Runs ``self._register_worker_plugin`` through ``self.sync`` and returns
    its result.
    """
    call_kwargs = {"plugin": plugin, "name": name, "idempotent": idempotent}
    return self.sync(self._register_worker_plugin, **call_kwargs)

async def _register_scheduler_plugin(
    self, plugin: SchedulerPlugin, name: str, idempotent: bool
):
    """Send a scheduler plugin to the scheduler for registration.

    Awaits ``self.scheduler.register_scheduler_plugin`` with
    ``dumps(plugin)`` as the plugin payload, forwarding ``name`` and
    ``idempotent`` unchanged, and returns the scheduler's reply.
    """
    # Look up the remote handler first so attribute access still happens
    # before ``dumps`` is evaluated, as in a single call expression.
    register = self.scheduler.register_scheduler_plugin
    reply = await register(
        plugin=dumps(plugin),
        name=name,
        idempotent=idempotent,
    )
    return reply

def register_scheduler_plugin(self, plugin, name=None, idempotent=False):
"""Register a scheduler plugin.
def register_scheduler_plugin(
self, plugin: SchedulerPlugin, name: str | None = None, idempotent: bool = False
):
"""
Register a scheduler plugin.
.. deprecated:: 2023.9.2
Use :meth:`Client.register_plugin` instead.
See https://distributed.readthedocs.io/en/latest/plugins.html#scheduler-plugins
Expand All @@ -4836,15 +4913,13 @@ def register_scheduler_plugin(self, plugin, name=None, idempotent=False):
idempotent : bool
Do not re-register if a plugin of the given name already exists.
"""
if name is None:
name = _get_plugin_name(plugin)

return self.sync(
self._register_scheduler_plugin,
plugin=plugin,
name=name,
idempotent=idempotent,
warnings.warn(
"`Client.register_scheduler_plugin` has been deprecated; "
"please `Client.register_plugin` instead",
DeprecationWarning,
stacklevel=2,
)
return cast(OKMessage, self.register_plugin(plugin, name, idempotent))

async def _unregister_scheduler_plugin(self, name):
    """Ask the scheduler to unregister the plugin called ``name``.

    Awaits ``self.scheduler.unregister_scheduler_plugin(name=name)`` and
    returns its result.
    """
    unregister = self.scheduler.unregister_scheduler_plugin
    reply = await unregister(name=name)
    return reply
Expand Down Expand Up @@ -4875,7 +4950,7 @@ def unregister_scheduler_plugin(self, name):
... pass
>>> plugin = MyPlugin(1, 2, 3)
>>> client.register_scheduler_plugin(plugin, name='foo')
>>> client.register_plugin(plugin, name='foo')
>>> client.unregister_scheduler_plugin(name='foo')
See Also
Expand All @@ -4902,41 +4977,50 @@ def register_worker_callbacks(self, setup=None):
setup : callable(dask_worker: Worker) -> None
Function to register and run on all workers
"""
return self.register_worker_plugin(_WorkerSetupPlugin(setup))
return self.register_plugin(_WorkerSetupPlugin(setup))

async def _register_worker_plugin(self, plugin=None, name=None, nanny=None):
if nanny or nanny is None and isinstance(plugin, NannyPlugin):
if not isinstance(plugin, NannyPlugin):
warnings.warn(
"Registering duck-typed plugins has been deprecated. "
"Please make sure your plugin subclasses `NannyPlugin`.",
DeprecationWarning,
stacklevel=2,
)
method = self.scheduler.register_nanny_plugin
else:
if not isinstance(plugin, WorkerPlugin):
warnings.warn(
"Registering duck-typed plugins has been deprecated. "
"Please make sure your plugin subclasses `WorkerPlugin`.",
DeprecationWarning,
stacklevel=2,
async def _register_worker_plugin(
self, plugin: WorkerPlugin, name: str, idempotent: bool
) -> dict[str, OKMessage]:
responses = await self.scheduler.register_worker_plugin(
plugin=dumps(plugin), name=name, idempotent=idempotent
)
for response in responses.values():
if response["status"] == "error":
_, exc, tb = clean_exception(
response["exception"], response["traceback"]
)
method = self.scheduler.register_worker_plugin
assert exc
raise exc.with_traceback(tb)
return cast(dict[str, OKMessage], responses)

responses = await method(plugin=dumps(plugin), name=name)
async def _register_nanny_plugin(
self, plugin: NannyPlugin, name: str, idempotent: bool
) -> dict[str, OKMessage]:
responses = await self.scheduler.register_nanny_plugin(
plugin=dumps(plugin), name=name, idempotent=idempotent
)
for response in responses.values():
if response["status"] == "error":
_, exc, tb = clean_exception(
response["exception"], response["traceback"]
)
assert exc
raise exc.with_traceback(tb)
return responses
return cast(dict[str, OKMessage], responses)

def register_worker_plugin(self, plugin=None, name=None, nanny=None):
def register_worker_plugin(
self,
plugin: NannyPlugin | WorkerPlugin,
name: str | None = None,
nanny: bool | None = None,
):
"""
Registers a lifecycle worker plugin for all current and future workers.
.. deprecated:: 2023.9.2
Use :meth:`Client.register_plugin` instead.
This registers a new object to handle setup, task state transitions and
teardown for workers in this cluster. The plugin will instantiate
itself on all currently connected workers. It will also be run on any
Expand Down Expand Up @@ -4982,11 +5066,11 @@ def register_worker_plugin(self, plugin=None, name=None, nanny=None):
... pass
>>> plugin = MyPlugin(1, 2, 3)
>>> client.register_worker_plugin(plugin)
>>> client.register_plugin(plugin)
You can get access to the plugin with the ``get_worker`` function
>>> client.register_worker_plugin(other_plugin, name='my-plugin')
>>> client.register_plugin(other_plugin, name='my-plugin')
>>> def f():
... worker = get_worker()
... plugin = worker.plugins['my-plugin']
Expand All @@ -4999,14 +5083,70 @@ def register_worker_plugin(self, plugin=None, name=None, nanny=None):
distributed.WorkerPlugin
unregister_worker_plugin
"""
warnings.warn(
"`Client.register_worker_plugin` has been deprecated; "
"please use `Client.register_plugin` instead",
DeprecationWarning,
stacklevel=2,
)
if name is None:
name = _get_plugin_name(plugin)

assert name

return self.sync(
self._register_worker_plugin, plugin=plugin, name=name, nanny=nanny
)
method: Callable
if isinstance(plugin, WorkerPlugin):
method = self._register_worker_plugin
if nanny is True:
warnings.warn(
"Registering a `WorkerPlugin` as a nanny plugin is not "
"allowed, registering as a worker plugin instead. "
"To register as a nanny plugin, inherit from `NannyPlugin`.",
UserWarning,
stacklevel=2,
)
elif isinstance(plugin, NannyPlugin):
method = self._register_nanny_plugin
if nanny is False:
warnings.warn(
"Registering a `NannyPlugin` as a worker plugin is not "
"allowed, registering as a nanny plugin instead. "
"To register as a worker plugin, inherit from `WorkerPlugin`.",
UserWarning,
stacklevel=2,
)
elif isinstance(plugin, SchedulerPlugin): # type: ignore[unreachable]
if nanny:
warnings.warn(
"Registering a `SchedulerPlugin` as a nanny plugin is not "
"allowed, registering as a scheduler plugin instead. "
"To register as a nanny plugin, inherit from `NannyPlugin`.",
UserWarning,
stacklevel=2,
)
else:
warnings.warn(
"Registering a `SchedulerPlugin` as a worker plugin is not "
"allowed, registering as a scheduler plugin instead. "
"To register as a worker plugin, inherit from `WorkerPlugin`.",
UserWarning,
stacklevel=2,
)
method = self._register_scheduler_plugin
else:
warnings.warn(
"Registering duck-typed plugins has been deprecated. "
"Please make sure your plugin inherits from `NannyPlugin` "
"or `WorkerPlugin`.",
DeprecationWarning,
stacklevel=2,
)
if nanny is True:
method = self._register_nanny_plugin
else:
method = self._register_worker_plugin

return self.sync(method, plugin=plugin, name=name, idempotent=False)

async def _unregister_worker_plugin(self, name, nanny=None):
if nanny:
Expand All @@ -5030,7 +5170,7 @@ def unregister_worker_plugin(self, name, nanny=None):
Parameters
----------
name : str
Name of the plugin to unregister. See the :meth:`Client.register_worker_plugin`
Name of the plugin to unregister. See the :meth:`Client.register_plugin`
docstring for more information.
Examples
Expand All @@ -5048,12 +5188,12 @@ def unregister_worker_plugin(self, name, nanny=None):
... pass
>>> plugin = MyPlugin(1, 2, 3)
>>> client.register_worker_plugin(plugin, name='foo')
>>> client.register_plugin(plugin, name='foo')
>>> client.unregister_worker_plugin(name='foo')
See Also
--------
register_worker_plugin
register_plugin
"""
return self.sync(self._unregister_worker_plugin, name=name, nanny=nanny)

Expand Down Expand Up @@ -5190,7 +5330,7 @@ class : logging.StreamHandler
# removed and torn down (see distributed.worker.Worker.plugin_add()), so
# this is effectively idempotent, i.e., forwarding the same logger twice
# won't cause every LogRecord to be forwarded twice
return self.register_worker_plugin(
return self.register_plugin(
ForwardLoggingPlugin(logger_name, level, topic), plugin_name
)

Expand Down
Loading

0 comments on commit 98e283e

Please sign in to comment.