diff --git a/cpp/src/utils/callback_notifier.cpp b/cpp/src/utils/callback_notifier.cpp
index 5a521d64..aefef066 100644
--- a/cpp/src/utils/callback_notifier.cpp
+++ b/cpp/src/utils/callback_notifier.cpp
@@ -65,7 +65,7 @@ bool CallbackNotifier::wait(uint64_t period,
   bool ret = false;
   for (size_t i = 0; i < attempts; ++i) {
     ret = _conditionVariable.wait_for(
-      lock, std::chrono::duration(period), [this]() {
+      lock, std::chrono::duration(signalInterval), [this]() {
         return _flag.load(std::memory_order_relaxed) == true;
       });
     if (signalWorkerFunction) signalWorkerFunction();
diff --git a/python/distributed-ucxx/distributed_ucxx/__init__.py b/python/distributed-ucxx/distributed_ucxx/__init__.py
index a73aa8fb..afdbf31e 100644
--- a/python/distributed-ucxx/distributed_ucxx/__init__.py
+++ b/python/distributed-ucxx/distributed_ucxx/__init__.py
@@ -1,5 +1,4 @@
 from .ucxx import UCXXBackend, UCXXConnector, UCXXListener  # noqa: F401
-from . import distributed_patches  # noqa: F401
 
 from ._version import __git_commit__, __version__
 
diff --git a/python/distributed-ucxx/distributed_ucxx/distributed_patches.py b/python/distributed-ucxx/distributed_ucxx/distributed_patches.py
deleted file mode 100644
index 4887c230..00000000
--- a/python/distributed-ucxx/distributed_ucxx/distributed_patches.py
+++ /dev/null
@@ -1,41 +0,0 @@
-from distributed import Scheduler, Worker
-from distributed.utils import log_errors
-
-import ucxx
-
-_scheduler_close = Scheduler.close
-_worker_close = Worker.close
-
-
-def _stop_notifier_thread_and_progress_tasks():
-    ucxx.stop_notifier_thread()
-    ucxx.core._get_ctx().progress_tasks.clear()
-
-
-async def _scheduler_close_ucxx(*args, **kwargs):
-    scheduler = args[0]  # args[0] == self
-
-    await _scheduler_close(*args, **kwargs)
-
-    is_ucxx = any([addr.startswith("ucxx") for addr in scheduler._start_address])
-
-    if is_ucxx:
-        _stop_notifier_thread_and_progress_tasks()
-
-
-@log_errors
-async def _worker_close_ucxx(*args, **kwargs):
-    # This patch is insufficient for `dask worker` when `--nworkers=1` (default) or
-    # `--no-nanny` is specified because there's no good way to detect that the
-    # `distributed.Worker.close()` method should stop the notifier thread.
-
-    worker = args[0]  # args[0] == self
-
-    await _worker_close(*args, **kwargs)
-
-    if worker._protocol.startswith("ucxx") and worker.nanny is not None:
-        _stop_notifier_thread_and_progress_tasks()
-
-
-Scheduler.close = _scheduler_close_ucxx
-Worker.close = _worker_close_ucxx
diff --git a/python/distributed-ucxx/distributed_ucxx/ucxx.py b/python/distributed-ucxx/distributed_ucxx/ucxx.py
index a9a38b8b..1f5fc1df 100644
--- a/python/distributed-ucxx/distributed_ucxx/ucxx.py
+++ b/python/distributed-ucxx/distributed_ucxx/ucxx.py
@@ -8,6 +8,7 @@ from __future__ import annotations
 
 import functools
+import itertools
 import logging
 import os
 import struct
 
@@ -90,6 +91,95 @@ def synchronize_stream(stream=0):
         stream.synchronize()
 
 
+def make_register():
+    count = itertools.count()
+
+    def register() -> int:
+        """Register a Dask resource with the UCXX context.
+
+        Register a Dask resource with the UCXX context and keep track of it with the
+        use of a unique ID for the resource. The resource ID is later used to
+        deregister the resource from the UCXX context by calling
+        `_deregister_dask_resource(resource_id)`, which stops the notifier thread
+        and progress tasks when no more UCXX resources are alive.
+
+        Returns
+        -------
+        resource_id: int
+            The ID of the registered resource that should be used with
+            `_deregister_dask_resource` during stop/destruction of the resource.
+        """
+        ctx = ucxx.core._get_ctx()
+        with ctx._dask_resources_lock:
+            resource_id = next(count)
+            ctx._dask_resources.add(resource_id)
+            ctx.start_notifier_thread()
+            ctx.continuous_ucx_progress()
+            return resource_id
+
+    return register
+
+
+_register_dask_resource = make_register()
+
+del make_register
+
+
+def _deregister_dask_resource(resource_id):
+    """Deregister a Dask resource from the UCXX context.
+
+    Deregister a Dask resource from the UCXX context with the given ID, and if no
+    resources remain after deregistration, stop the notifier thread and progress
+    tasks.
+
+    Parameters
+    ----------
+    resource_id: int
+        The unique ID of the resource returned by `_register_dask_resource` upon
+        registration.
+    """
+    if ucxx.core._ctx is None:
+        # Prevent creation of context if it was already destroyed; all
+        # registered references are already gone.
+        return
+
+    ctx = ucxx.core._get_ctx()
+
+    # Check if the attribute exists first, in tests the UCXX context may have
+    # been reset before some resources are deregistered.
+    if hasattr(ctx, "_dask_resources_lock"):
+        with ctx._dask_resources_lock:
+            try:
+                ctx._dask_resources.remove(resource_id)
+            except KeyError:
+                pass
+
+            # Stop notifier thread and progress tasks if no Dask resources using
+            # UCXX communicators are running anymore.
+            if len(ctx._dask_resources) == 0:
+                ctx.stop_notifier_thread()
+                ctx.progress_tasks.clear()
+
+
+def _allocate_dask_resources_tracker() -> None:
+    """Allocate Dask resources tracker.
+
+    Allocate a Dask resources tracker in the UCXX context. This is useful to
+    track Distributed communicators so that progress and notifier threads can
+    be cleanly stopped when no UCXX communicators are alive anymore.
+    """
+    ctx = ucxx.core._get_ctx()
+    if not hasattr(ctx, "_dask_resources"):
+        # TODO: Move the `Lock` to a file/module-level variable for true
+        # lock-safety. The approach implemented below could cause race
+        # conditions if this function is called simultaneously by multiple
+        # threads.
+        from threading import Lock
+
+        ctx._dask_resources = set()
+        ctx._dask_resources_lock = Lock()
+
+
 def init_once():
     global ucxx, device_array
     global ucx_create_endpoint, ucx_create_listener
@@ -97,6 +187,11 @@ def init_once():
     global multi_buffer
 
     if ucxx is not None:
+        # Ensure reallocation of the Dask resources tracker if the UCXX context was
+        # reset since the previous `init_once()` call. This may happen in tests,
+        # where the `ucxx_loop` fixture will reset the context after each test.
+        _allocate_dask_resources_tracker()
+
         return
 
     # remove/process dask.ucx flags for valid ucx options
@@ -159,6 +254,7 @@ def init_once():
 
     # environment, so the user's external environment can safely
     # override things here.
     ucxx.init(options=ucx_config, env_takes_precedence=True)
+    _allocate_dask_resources_tracker()
 
     pool_size_str = dask.config.get("distributed.rmm.pool-size")
@@ -279,8 +375,12 @@ def __init__(  # type: ignore[no-untyped-def]
         else:
             self._has_close_callback = False
 
+        self._resource_id = _register_dask_resource()
+
         logger.debug("UCX.__init__ %s", self)
 
+        weakref.finalize(self, _deregister_dask_resource, self._resource_id)
+
     def __del__(self) -> None:
         self.abort()
 
@@ -488,6 +588,7 @@ def abort(self):
         if self._ep is not None:
             self._ep.abort()
             self._ep = None
+            _deregister_dask_resource(self._resource_id)
 
     def closed(self):
         if self._has_close_callback is True:
@@ -522,6 +623,7 @@ async def connect(
         init_once()
 
         try:
+            self._resource_id = _register_dask_resource()
             ep = await ucxx.create_endpoint(ip, port)
         except (
             ucxx.exceptions.UCXCloseError,
@@ -532,6 +634,8 @@
             ucxx.exceptions.UCXUnreachableError,
         ):
             raise CommClosedError("Connection closed before handshake completed")
+        finally:
+            _deregister_dask_resource(self._resource_id)
         return self.comm_class(
             ep,
             local_addr="",
@@ -589,10 +693,13 @@ async def serve_forever(client_ep):
                 await self.comm_handler(ucx)
 
         init_once()
+        self._resource_id = _register_dask_resource()
+        weakref.finalize(self, _deregister_dask_resource, self._resource_id)
         self.ucxx_server = ucxx.create_listener(serve_forever, port=self._input_port)
 
     def stop(self):
         self.ucxx_server = None
+        _deregister_dask_resource(self._resource_id)
 
     def get_host_port(self):
         # TODO: TCP raises if this hasn't started yet.
diff --git a/python/ucxx/ucxx/_lib_async/application_context.py b/python/ucxx/ucxx/_lib_async/application_context.py
index 78a33f2f..0b08ad93 100644
--- a/python/ucxx/ucxx/_lib_async/application_context.py
+++ b/python/ucxx/ucxx/_lib_async/application_context.py
@@ -40,7 +40,7 @@ def __init__(
         enable_python_future=None,
         exchange_peer_info_timeout=10.0,
     ):
-        self.progress_tasks = []
+        self.progress_tasks = dict()
         self.notifier_thread_q = None
         self.notifier_thread = None
         self._listener_active_clients = ActiveClients()
@@ -197,7 +197,7 @@ def worker_address(self):
         return self.worker.address
 
     def start_notifier_thread(self):
-        if self.worker.enable_python_future:
+        if self.worker.enable_python_future and self.notifier_thread is None:
             logger.debug("UCXX_ENABLE_PYTHON available, enabling notifier thread")
             loop = get_event_loop()
             self.notifier_thread_q = Queue()
@@ -234,6 +234,7 @@ def stop_notifier_thread(self):
                 # call otherwise.
                 self.notifier_thread.join(timeout=0.01)
                 if not self.notifier_thread.is_alive():
+                    self.notifier_thread = None
                     break
             logger.debug("Notifier thread stopped")
         else:
@@ -469,7 +470,7 @@ def continuous_ucx_progress(self, event_loop=None):
         elif self.progress_mode == "blocking":
             task = BlockingMode(self.worker, loop)
 
-        self.progress_tasks.append(task)
+        self.progress_tasks[loop] = task
 
     def get_ucp_worker(self):
         """Returns the underlying UCP worker handle (ucp_worker_h)
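
The diff above replaces the deleted `distributed_patches` monkey-patching of `Scheduler.close`/`Worker.close` with reference-counted "Dask resources" held by the UCXX context: the first registration starts the notifier thread and progress tasks, deregistration of the last live resource stops them, and `weakref.finalize` covers comms that are garbage-collected without an explicit `abort()`/`stop()`. A minimal self-contained sketch of that pattern follows; `Context`, `Comm`, and the `notifier_running` flag are illustrative stand-ins, not ucxx API:

import itertools
import threading
import weakref


class Context:
    """Illustrative stand-in for the UCXX application context."""

    def __init__(self):
        self._counter = itertools.count()
        self._resources = set()
        self._lock = threading.Lock()
        # Stands in for the notifier thread and progress tasks.
        self.notifier_running = False

    def register(self) -> int:
        # First live resource starts the background machinery.
        with self._lock:
            resource_id = next(self._counter)
            self._resources.add(resource_id)
            self.notifier_running = True
            return resource_id

    def deregister(self, resource_id: int) -> None:
        # `discard` tolerates IDs that were already removed, mirroring the
        # `except KeyError: pass` in `_deregister_dask_resource`.
        with self._lock:
            self._resources.discard(resource_id)
            if not self._resources:
                # Last live resource stops the background machinery.
                self.notifier_running = False


ctx = Context()


class Comm:
    """Mirrors how `UCX.__init__` pairs registration with a finalizer."""

    def __init__(self):
        self._resource_id = ctx.register()
        # Runs at garbage collection even if `abort()` is never called.
        weakref.finalize(self, ctx.deregister, self._resource_id)

    def abort(self):
        # Explicit teardown, like `UCX.abort()`; the finalizer that fires
        # later becomes a harmless no-op.
        ctx.deregister(self._resource_id)


a, b = Comm(), Comm()
assert ctx.notifier_running
a.abort()   # one resource still alive, machinery keeps running
assert ctx.notifier_running
del a, b    # finalizers fire (immediate under CPython refcounting)
assert not ctx.notifier_running

Because deregistration tolerates already-removed IDs, the explicit teardown paths (`abort()`, `stop()`) and the finalizer path can overlap safely: whichever runs second finds the ID already gone and does nothing. This is why the diff can both call `_deregister_dask_resource` directly and attach a `weakref.finalize` for the same resource ID.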