Skip to content

Commit

Permalink
Merge branch 'branch-0.41' into python-async-blocking-mode
Browse files Browse the repository at this point in the history
  • Loading branch information
pentschev authored Oct 4, 2024
2 parents 4f7a7f2 + 1f4e508 commit 23bb0bb
Show file tree
Hide file tree
Showing 5 changed files with 112 additions and 46 deletions.
2 changes: 1 addition & 1 deletion cpp/src/utils/callback_notifier.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ bool CallbackNotifier::wait(uint64_t period,
bool ret = false;
for (size_t i = 0; i < attempts; ++i) {
ret = _conditionVariable.wait_for(
lock, std::chrono::duration<uint64_t, std::nano>(period), [this]() {
lock, std::chrono::duration<uint64_t, std::nano>(signalInterval), [this]() {
return _flag.load(std::memory_order_relaxed) == true;
});
if (signalWorkerFunction) signalWorkerFunction();
Expand Down
1 change: 0 additions & 1 deletion python/distributed-ucxx/distributed_ucxx/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from .ucxx import UCXXBackend, UCXXConnector, UCXXListener # noqa: F401
from . import distributed_patches # noqa: F401


from ._version import __git_commit__, __version__
41 changes: 0 additions & 41 deletions python/distributed-ucxx/distributed_ucxx/distributed_patches.py

This file was deleted.

107 changes: 107 additions & 0 deletions python/distributed-ucxx/distributed_ucxx/ucxx.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from __future__ import annotations

import functools
import itertools
import logging
import os
import struct
Expand Down Expand Up @@ -90,13 +91,107 @@ def synchronize_stream(stream=0):
stream.synchronize()


def make_register():
count = itertools.count()

def register() -> int:
"""Register a Dask resource with the UCXX context.
Register a Dask resource with the UCXX context and keep track of it with the
use of a unique ID for the resource. The resource ID is later used to
deregister the resource from the UCXX context calling
`_deregister_dask_resource(resource_id)`, which stops the notifier thread
and progress tasks when no more UCXX resources are alive.
Returns
-------
resource_id: int
The ID of the registered resource that should be used with
`_deregister_dask_resource` during stop/destruction of the resource.
"""
ctx = ucxx.core._get_ctx()
with ctx._dask_resources_lock:
resource_id = next(count)
ctx._dask_resources.add(resource_id)
ctx.start_notifier_thread()
ctx.continuous_ucx_progress()
return resource_id

return register


_register_dask_resource = make_register()

del make_register


def _deregister_dask_resource(resource_id):
"""Deregister a Dask resource with the UCXX context.
Deregister a Dask resource from the UCXX context with given ID, and if no
resources remain after deregistration, stop the notifier thread and progress
tasks.
Parameters
----------
resource_id: int
The unique ID of the resource returned by `_register_dask_resource` upon
registration.
"""
if ucxx.core._ctx is None:
# Prevent creation of context if it was already destroyed, all
# registered references are already gone.
return

ctx = ucxx.core._get_ctx()

# Check if the attribute exists first, in tests the UCXX context may have
# been reset before some resources are deregistered.
if hasattr(ctx, "_dask_resources_lock"):
with ctx._dask_resources_lock:
try:
ctx._dask_resources.remove(resource_id)
except KeyError:
pass

# Stop notifier thread and progress tasks if no Dask resources using
# UCXX communicators are running anymore.
if len(ctx._dask_resources) == 0:
ctx.stop_notifier_thread()
ctx.progress_tasks.clear()


def _allocate_dask_resources_tracker() -> None:
"""Allocate Dask resources tracker.
Allocate a Dask resources tracker in the UCXX context. This is useful to
track Distributed communicators so that progress and notifier threads can
be cleanly stopped when no UCXX communicators are alive anymore.
"""
ctx = ucxx.core._get_ctx()
if not hasattr(ctx, "_dask_resources"):
# TODO: Move the `Lock` to a file/module-level variable for true
# lock-safety. The approach implemented below could cause race
# conditions if this function is called simultaneously by multiple
# threads.
from threading import Lock

ctx._dask_resources = set()
ctx._dask_resources_lock = Lock()


def init_once():
global ucxx, device_array
global ucx_create_endpoint, ucx_create_listener
global pre_existing_cuda_context, cuda_context_created
global multi_buffer

if ucxx is not None:
# Ensure reallocation of Dask resources tracker if the UCXX context was
# reset since the previous `init_once()` call. This may happen in tests,
# where the `ucxx_loop` fixture will reset the context after each test.
_allocate_dask_resources_tracker()

return

# remove/process dask.ucx flags for valid ucx options
Expand Down Expand Up @@ -159,6 +254,7 @@ def init_once():
# environment, so the user's external environment can safely
# override things here.
ucxx.init(options=ucx_config, env_takes_precedence=True)
_allocate_dask_resources_tracker()

pool_size_str = dask.config.get("distributed.rmm.pool-size")

Expand Down Expand Up @@ -279,8 +375,12 @@ def __init__( # type: ignore[no-untyped-def]
else:
self._has_close_callback = False

self._resource_id = _register_dask_resource()

logger.debug("UCX.__init__ %s", self)

weakref.finalize(self, _deregister_dask_resource, self._resource_id)

def __del__(self) -> None:
self.abort()

Expand Down Expand Up @@ -488,6 +588,7 @@ def abort(self):
if self._ep is not None:
self._ep.abort()
self._ep = None
_deregister_dask_resource(self._resource_id)

def closed(self):
if self._has_close_callback is True:
Expand Down Expand Up @@ -522,6 +623,7 @@ async def connect(
init_once()

try:
self._resource_id = _register_dask_resource()
ep = await ucxx.create_endpoint(ip, port)
except (
ucxx.exceptions.UCXCloseError,
Expand All @@ -532,6 +634,8 @@ async def connect(
ucxx.exceptions.UCXUnreachableError,
):
raise CommClosedError("Connection closed before handshake completed")
finally:
_deregister_dask_resource(self._resource_id)
return self.comm_class(
ep,
local_addr="",
Expand Down Expand Up @@ -589,10 +693,13 @@ async def serve_forever(client_ep):
await self.comm_handler(ucx)

init_once()
self._resource_id = _register_dask_resource()
weakref.finalize(self, _deregister_dask_resource, self._resource_id)
self.ucxx_server = ucxx.create_listener(serve_forever, port=self._input_port)

def stop(self):
self.ucxx_server = None
_deregister_dask_resource(self._resource_id)

def get_host_port(self):
# TODO: TCP raises if this hasn't started yet.
Expand Down
7 changes: 4 additions & 3 deletions python/ucxx/ucxx/_lib_async/application_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def __init__(
enable_python_future=None,
exchange_peer_info_timeout=10.0,
):
self.progress_tasks = []
self.progress_tasks = dict()
self.notifier_thread_q = None
self.notifier_thread = None
self._listener_active_clients = ActiveClients()
Expand Down Expand Up @@ -197,7 +197,7 @@ def worker_address(self):
return self.worker.address

def start_notifier_thread(self):
if self.worker.enable_python_future:
if self.worker.enable_python_future and self.notifier_thread is None:
logger.debug("UCXX_ENABLE_PYTHON available, enabling notifier thread")
loop = get_event_loop()
self.notifier_thread_q = Queue()
Expand Down Expand Up @@ -234,6 +234,7 @@ def stop_notifier_thread(self):
# call otherwise.
self.notifier_thread.join(timeout=0.01)
if not self.notifier_thread.is_alive():
self.notifier_thread = None
break
logger.debug("Notifier thread stopped")
else:
Expand Down Expand Up @@ -469,7 +470,7 @@ def continuous_ucx_progress(self, event_loop=None):
elif self.progress_mode == "blocking":
task = BlockingMode(self.worker, loop)

self.progress_tasks.append(task)
self.progress_tasks[loop] = task

def get_ucp_worker(self):
"""Returns the underlying UCP worker handle (ucp_worker_h)
Expand Down

0 comments on commit 23bb0bb

Please sign in to comment.