From 1f4e508423c2916df5f0f622bbd469dd549718d2 Mon Sep 17 00:00:00 2001
From: Peter Andreas Entschev
Date: Fri, 4 Oct 2024 17:32:00 +0200
Subject: [PATCH] Rewrite distributed-ucxx comms resources tracking (#291)

Rewrite tracking of distributed-ucxx communicator resources to provide more
robust lifetime tracking and better control over termination of the notifier
thread and progress tasks. This appears to have resolved the deadlock issues
in Distributed for good this time; if not, it has at least improved the
situation substantially.

Another benefit is the removal of `distributed_patches.py`: there is no
longer any need to monkey-patch Distributed resources to do the tracking.
Instead, tracking is done on the communication resources alone, which are
entirely controlled by `distributed-ucxx`. In the long run this should also
cause less breakage, since we previously depended on Distributed's start/stop
mechanisms remaining unchanged.
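In essence, the new scheme reference-counts communicator resources against
the lifetime of the notifier thread and progress tasks. The following is a
simplified, self-contained sketch of that pattern; `Tracker`, `Comm` and the
`running` flag are illustrative stand-ins, not the real distributed-ucxx or
ucxx API:

    import itertools
    import threading
    import weakref


    class Tracker:
        """Reference-count resources against background machinery."""

        def __init__(self):
            self._count = itertools.count()
            self._resources = set()
            self._lock = threading.Lock()
            # Stands in for the notifier thread and progress tasks.
            self.running = False

        def register(self) -> int:
            # Hand out a unique ID and (re)start the background machinery,
            # mirroring `_register_dask_resource()`.
            with self._lock:
                resource_id = next(self._count)
                self._resources.add(resource_id)
                self.running = True
            return resource_id

        def deregister(self, resource_id: int) -> None:
            # Idempotent removal; stop the machinery once no resources
            # remain, mirroring `_deregister_dask_resource()`.
            with self._lock:
                self._resources.discard(resource_id)
                if not self._resources:
                    self.running = False


    tracker = Tracker()


    class Comm:
        def __init__(self):
            self._resource_id = tracker.register()
            # Safety net: deregister even if close() is never called.
            weakref.finalize(self, tracker.deregister, self._resource_id)

        def close(self):
            tracker.deregister(self._resource_id)


    comm = Comm()
    assert tracker.running
    comm.close()
    assert not tracker.running

The `weakref.finalize` safety net mirrors what this patch installs on `UCX`
comms and `UCXXListener`, so resources that are garbage-collected without an
explicit `abort()`/`stop()` are still deregistered.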
Authors:
  - Peter Andreas Entschev (https://github.com/pentschev)

Approvers:
  - Lawrence Mitchell (https://github.com/wence-)

URL: https://github.com/rapidsai/ucxx/pull/291
---
 cpp/src/utils/callback_notifier.cpp           |   2 +-
 .../distributed_ucxx/__init__.py              |   1 -
 .../distributed_ucxx/distributed_patches.py   |  41 ------
 .../distributed-ucxx/distributed_ucxx/ucxx.py | 107 ++++++++++++++++++
 .../ucxx/_lib_async/application_context.py    |   7 +-
 5 files changed, 112 insertions(+), 46 deletions(-)
 delete mode 100644 python/distributed-ucxx/distributed_ucxx/distributed_patches.py

diff --git a/cpp/src/utils/callback_notifier.cpp b/cpp/src/utils/callback_notifier.cpp
index 5a521d64..aefef066 100644
--- a/cpp/src/utils/callback_notifier.cpp
+++ b/cpp/src/utils/callback_notifier.cpp
@@ -65,7 +65,7 @@ bool CallbackNotifier::wait(uint64_t period,
   bool ret = false;
   for (size_t i = 0; i < attempts; ++i) {
     ret = _conditionVariable.wait_for(
-      lock, std::chrono::duration(period), [this]() {
+      lock, std::chrono::duration(signalInterval), [this]() {
         return _flag.load(std::memory_order_relaxed) == true;
       });
     if (signalWorkerFunction) signalWorkerFunction();
diff --git a/python/distributed-ucxx/distributed_ucxx/__init__.py b/python/distributed-ucxx/distributed_ucxx/__init__.py
index a73aa8fb..afdbf31e 100644
--- a/python/distributed-ucxx/distributed_ucxx/__init__.py
+++ b/python/distributed-ucxx/distributed_ucxx/__init__.py
@@ -1,5 +1,4 @@
 from .ucxx import UCXXBackend, UCXXConnector, UCXXListener  # noqa: F401
-from . import distributed_patches  # noqa: F401
 
 from ._version import __git_commit__, __version__
diff --git a/python/distributed-ucxx/distributed_ucxx/distributed_patches.py b/python/distributed-ucxx/distributed_ucxx/distributed_patches.py
deleted file mode 100644
index 4887c230..00000000
--- a/python/distributed-ucxx/distributed_ucxx/distributed_patches.py
+++ /dev/null
@@ -1,41 +0,0 @@
-from distributed import Scheduler, Worker
-from distributed.utils import log_errors
-
-import ucxx
-
-_scheduler_close = Scheduler.close
-_worker_close = Worker.close
-
-
-def _stop_notifier_thread_and_progress_tasks():
-    ucxx.stop_notifier_thread()
-    ucxx.core._get_ctx().progress_tasks.clear()
-
-
-async def _scheduler_close_ucxx(*args, **kwargs):
-    scheduler = args[0]  # args[0] == self
-
-    await _scheduler_close(*args, **kwargs)
-
-    is_ucxx = any([addr.startswith("ucxx") for addr in scheduler._start_address])
-
-    if is_ucxx:
-        _stop_notifier_thread_and_progress_tasks()
-
-
-@log_errors
-async def _worker_close_ucxx(*args, **kwargs):
-    # This patch is insufficient for `dask worker` when `--nworkers=1` (default) or
-    # `--no-nanny` is specified because there's no good way to detect that the
-    # `distributed.Worker.close()` method should stop the notifier thread.
-
-    worker = args[0]  # args[0] == self
-
-    await _worker_close(*args, **kwargs)
-
-    if worker._protocol.startswith("ucxx") and worker.nanny is not None:
-        _stop_notifier_thread_and_progress_tasks()
-
-
-Scheduler.close = _scheduler_close_ucxx
-Worker.close = _worker_close_ucxx
diff --git a/python/distributed-ucxx/distributed_ucxx/ucxx.py b/python/distributed-ucxx/distributed_ucxx/ucxx.py
index a9a38b8b..1f5fc1df 100644
--- a/python/distributed-ucxx/distributed_ucxx/ucxx.py
+++ b/python/distributed-ucxx/distributed_ucxx/ucxx.py
@@ -8,6 +8,7 @@ from __future__ import annotations
 
 import functools
+import itertools
 import logging
 import os
 import struct
@@ -90,6 +91,95 @@ def synchronize_stream(stream=0):
         stream.synchronize()
 
 
+def make_register():
+    count = itertools.count()
+
+    def register() -> int:
+        """Register a Dask resource with the UCXX context.
+
+        Register a Dask resource with the UCXX context and keep track of it with the
+        use of a unique ID for the resource. The resource ID is later used to
+        deregister the resource from the UCXX context calling
+        `_deregister_dask_resource(resource_id)`, which stops the notifier thread
+        and progress tasks when no more UCXX resources are alive.
+
+        Returns
+        -------
+        resource_id: int
+            The ID of the registered resource that should be used with
+            `_deregister_dask_resource` during stop/destruction of the resource.
+        """
+        ctx = ucxx.core._get_ctx()
+        with ctx._dask_resources_lock:
+            resource_id = next(count)
+            ctx._dask_resources.add(resource_id)
+            ctx.start_notifier_thread()
+            ctx.continuous_ucx_progress()
+            return resource_id
+
+    return register
+
+
+_register_dask_resource = make_register()
+
+del make_register
+
+
+def _deregister_dask_resource(resource_id):
+    """Deregister a Dask resource with the UCXX context.
+
+    Deregister a Dask resource from the UCXX context with given ID, and if no
+    resources remain after deregistration, stop the notifier thread and progress
+    tasks.
+
+    Parameters
+    ----------
+    resource_id: int
+        The unique ID of the resource returned by `_register_dask_resource` upon
+        registration.
+    """
+    if ucxx.core._ctx is None:
+        # Prevent creation of context if it was already destroyed, all
+        # registered references are already gone.
+        return
+
+    ctx = ucxx.core._get_ctx()
+
+    # Check if the attribute exists first, in tests the UCXX context may have
+    # been reset before some resources are deregistered.
+    if hasattr(ctx, "_dask_resources_lock"):
+        with ctx._dask_resources_lock:
+            try:
+                ctx._dask_resources.remove(resource_id)
+            except KeyError:
+                pass
+
+            # Stop notifier thread and progress tasks if no Dask resources using
+            # UCXX communicators are running anymore.
+            if len(ctx._dask_resources) == 0:
+                ctx.stop_notifier_thread()
+                ctx.progress_tasks.clear()
+
+
+def _allocate_dask_resources_tracker() -> None:
+    """Allocate Dask resources tracker.
+
+    Allocate a Dask resources tracker in the UCXX context. This is useful to
+    track Distributed communicators so that progress and notifier threads can
+    be cleanly stopped when no UCXX communicators are alive anymore.
+    """
+    ctx = ucxx.core._get_ctx()
+    if not hasattr(ctx, "_dask_resources"):
+        # TODO: Move the `Lock` to a file/module-level variable for true
+        # lock-safety. The approach implemented below could cause race
+        # conditions if this function is called simultaneously by multiple
+        # threads.
+        from threading import Lock
+
+        ctx._dask_resources = set()
+        ctx._dask_resources_lock = Lock()
+
+
 def init_once():
     global ucxx, device_array
     global ucx_create_endpoint, ucx_create_listener
@@ -97,6 +187,11 @@ def init_once():
     global multi_buffer
 
     if ucxx is not None:
+        # Ensure reallocation of Dask resources tracker if the UCXX context was
+        # reset since the previous `init_once()` call. This may happen in tests,
+        # where the `ucxx_loop` fixture will reset the context after each test.
+        _allocate_dask_resources_tracker()
+
         return
 
     # remove/process dask.ucx flags for valid ucx options
@@ -159,6 +254,7 @@ def init_once():
     # environment, so the user's external environment can safely
     # override things here.
     ucxx.init(options=ucx_config, env_takes_precedence=True)
+    _allocate_dask_resources_tracker()
 
     pool_size_str = dask.config.get("distributed.rmm.pool-size")
 
@@ -279,8 +375,12 @@ def __init__(  # type: ignore[no-untyped-def]
         else:
             self._has_close_callback = False
 
+        self._resource_id = _register_dask_resource()
+
         logger.debug("UCX.__init__ %s", self)
 
+        weakref.finalize(self, _deregister_dask_resource, self._resource_id)
+
     def __del__(self) -> None:
         self.abort()
 
@@ -488,6 +588,7 @@ def abort(self):
         if self._ep is not None:
             self._ep.abort()
             self._ep = None
+            _deregister_dask_resource(self._resource_id)
 
     def closed(self):
         if self._has_close_callback is True:
@@ -522,6 +623,7 @@ async def connect(
         init_once()
 
         try:
+            self._resource_id = _register_dask_resource()
             ep = await ucxx.create_endpoint(ip, port)
         except (
             ucxx.exceptions.UCXCloseError,
@@ -532,6 +634,8 @@ async def connect(
             ucxx.exceptions.UCXUnreachableError,
         ):
             raise CommClosedError("Connection closed before handshake completed")
+        finally:
+            _deregister_dask_resource(self._resource_id)
 
         return self.comm_class(
             ep,
             local_addr="",
@@ -589,10 +693,13 @@ async def serve_forever(client_ep):
             await self.comm_handler(ucx)
 
         init_once()
+        self._resource_id = _register_dask_resource()
+        weakref.finalize(self, _deregister_dask_resource, self._resource_id)
         self.ucxx_server = ucxx.create_listener(serve_forever, port=self._input_port)
 
     def stop(self):
         self.ucxx_server = None
+        _deregister_dask_resource(self._resource_id)
 
     def get_host_port(self):
         # TODO: TCP raises if this hasn't started yet.
diff --git a/python/ucxx/ucxx/_lib_async/application_context.py b/python/ucxx/ucxx/_lib_async/application_context.py
index dd94a112..e91b91e9 100644
--- a/python/ucxx/ucxx/_lib_async/application_context.py
+++ b/python/ucxx/ucxx/_lib_async/application_context.py
@@ -40,7 +40,7 @@ def __init__(
         enable_python_future=None,
         exchange_peer_info_timeout=10.0,
     ):
-        self.progress_tasks = []
+        self.progress_tasks = dict()
         self.notifier_thread_q = None
         self.notifier_thread = None
         self._listener_active_clients = ActiveClients()
@@ -194,7 +194,7 @@ def worker_address(self):
         return self.worker.address
 
     def start_notifier_thread(self):
-        if self.worker.enable_python_future:
+        if self.worker.enable_python_future and self.notifier_thread is None:
             logger.debug("UCXX_ENABLE_PYTHON available, enabling notifier thread")
             loop = get_event_loop()
             self.notifier_thread_q = Queue()
@@ -231,6 +231,7 @@ def stop_notifier_thread(self):
                 # call otherwise.
                 self.notifier_thread.join(timeout=0.01)
                 if not self.notifier_thread.is_alive():
+                    self.notifier_thread = None
                     break
             logger.debug("Notifier thread stopped")
         else:
@@ -464,7 +465,7 @@ def continuous_ucx_progress(self, event_loop=None):
         elif self.progress_mode == "polling":
             task = PollingMode(self.worker, loop)
 
-        self.progress_tasks.append(task)
+        self.progress_tasks[loop] = task
 
     def get_ucp_worker(self):
         """Returns the underlying UCP worker handle (ucp_worker_h)
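A note on the `ApplicationContext` changes above: keying `progress_tasks` by
event loop, together with the `self.notifier_thread is None` guard, makes
registration idempotent, so `_register_dask_resource()` can safely call
`start_notifier_thread()` and `continuous_ucx_progress()` on every
registration without accumulating duplicates. A minimal sketch of the
dict-keyed idea, where `ProgressTask` and `continuous_progress` are
hypothetical stand-ins for the real `BlockingMode`/`PollingMode` tasks and
`continuous_ucx_progress()`:

    import asyncio


    class ProgressTask:
        """Placeholder for the per-loop progress machinery."""

        def __init__(self, loop):
            self.loop = loop


    progress_tasks = {}


    def continuous_progress(loop):
        # One task per event loop: re-registering on the same loop replaces
        # the previous entry instead of appending a duplicate, as the old
        # list-based bookkeeping did.
        progress_tasks[loop] = ProgressTask(loop)


    loop = asyncio.new_event_loop()
    continuous_progress(loop)
    continuous_progress(loop)
    assert len(progress_tasks) == 1
    loop.close()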