Merge branch 'branch-0.41' into python-async-blocking-mode
pentschev authored Oct 7, 2024
2 parents 23bb0bb + a7d36f5 commit 01cbe8a
Showing 2 changed files with 66 additions and 53 deletions.
35 changes: 30 additions & 5 deletions .github/workflows/pr.yaml
@@ -12,6 +12,7 @@ concurrency:
 jobs:
   pr-builder:
     needs:
+      - changed-files
       - checks
       - conda-cpp-build
       - docs-build
@@ -25,6 +26,25 @@ jobs:
       - wheel-tests-distributed-ucxx
     secrets: inherit
     uses: rapidsai/shared-workflows/.github/workflows/[email protected]
+    if: always()
+    with:
+      needs: ${{ toJSON(needs) }}
+  changed-files:
+    secrets: inherit
+    uses: rapidsai/shared-workflows/.github/workflows/[email protected]
+    with:
+      files_yaml: |
+        test_cpp:
+          - '**'
+          - '!.pre-commit-config.yaml'
+          - '!README.md'
+          - '!docs/**'
+          - '!python/**'
+        test_python:
+          - '**'
+          - '!.pre-commit-config.yaml'
+          - '!README.md'
+          - '!docs/**'
   checks:
     secrets: inherit
     uses: rapidsai/shared-workflows/.github/workflows/[email protected]
@@ -47,23 +67,26 @@ jobs:
       container_image: "rapidsai/ci-conda:latest"
       run_script: "ci/build_docs.sh"
   conda-cpp-tests:
-    needs: conda-cpp-build
+    needs: [conda-cpp-build, changed-files]
     secrets: inherit
     uses: rapidsai/shared-workflows/.github/workflows/[email protected]
+    if: fromJSON(needs.changed-files.outputs.changed_file_groups).test_cpp
     with:
       build_type: pull-request
       container-options: "--cap-add CAP_SYS_PTRACE --shm-size=8g --ulimit=nofile=1000000:1000000"
   conda-python-tests:
-    needs: conda-cpp-build
+    needs: [conda-cpp-build, changed-files]
     secrets: inherit
     uses: rapidsai/shared-workflows/.github/workflows/[email protected]
+    if: fromJSON(needs.changed-files.outputs.changed_file_groups).test_python
     with:
       build_type: pull-request
       container-options: "--cap-add CAP_SYS_PTRACE --shm-size=8g --ulimit=nofile=1000000:1000000"
   conda-python-distributed-tests:
-    needs: conda-cpp-build
+    needs: [conda-cpp-build, changed-files]
     secrets: inherit
     uses: rapidsai/shared-workflows/.github/workflows/[email protected]
+    if: fromJSON(needs.changed-files.outputs.changed_file_groups).test_python
     with:
       build_type: pull-request
       script: "ci/test_python_distributed.sh"
@@ -83,9 +106,10 @@ jobs:
       build_type: pull-request
       script: ci/build_wheel_ucxx.sh
   wheel-tests-ucxx:
-    needs: wheel-build-ucxx
+    needs: [wheel-build-ucxx, changed-files]
     secrets: inherit
     uses: rapidsai/shared-workflows/.github/workflows/[email protected]
+    if: fromJSON(needs.changed-files.outputs.changed_file_groups).test_python
     with:
       build_type: pull-request
       container-options: "--cap-add CAP_SYS_PTRACE --shm-size=8g --ulimit=nofile=1000000:1000000"
@@ -98,9 +122,10 @@ jobs:
       build_type: pull-request
       script: ci/build_wheel_distributed_ucxx.sh
   wheel-tests-distributed-ucxx:
-    needs: [wheel-build-ucxx, wheel-build-distributed-ucxx]
+    needs: [wheel-build-ucxx, wheel-build-distributed-ucxx, changed-files]
     secrets: inherit
     uses: rapidsai/shared-workflows/.github/workflows/[email protected]
+    if: fromJSON(needs.changed-files.outputs.changed_file_groups).test_python
     with:
       build_type: pull-request
       container-options: "--cap-add CAP_SYS_PTRACE --shm-size=8g --ulimit=nofile=1000000:1000000"
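
The new changed-files job classifies the pull request's modified paths into the test_cpp and test_python groups declared in files_yaml, and each test job's if: condition reads the resulting boolean back out of the changed_file_groups output, so docs-only or README-only changes skip the GPU test jobs; pr-builder gets if: always() with needs: ${{ toJSON(needs) }} so it can still evaluate the overall result when some of those jobs are skipped. The actual matching lives in rapidsai/shared-workflows; the following is only a minimal Python sketch of the grouping idea, with a hypothetical helper name and simplified fnmatch-based glob semantics rather than the workflow's real implementation:

from fnmatch import fnmatch

# Patterns copied from the files_yaml block above; bare patterns include,
# '!'-prefixed patterns exclude.
TEST_PYTHON = ["**", "!.pre-commit-config.yaml", "!README.md", "!docs/**"]


def group_matches(changed_paths, patterns):
    """Return True if any changed path is included by the group's patterns."""
    includes = [p for p in patterns if not p.startswith("!")]
    excludes = [p[1:] for p in patterns if p.startswith("!")]

    def hit(path):
        included = any(fnmatch(path, p) for p in includes)
        excluded = any(fnmatch(path, p) for p in excludes)
        return included and not excluded

    return any(hit(path) for path in changed_paths)


# A docs-only change leaves test_python false, so the Python test jobs skip:
print(group_matches(["docs/source/index.rst"], TEST_PYTHON))  # False
# A change under python/ keeps them running:
print(group_matches(["python/distributed-ucxx/distributed_ucxx/ucxx.py"], TEST_PYTHON))  # True
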
84 changes: 36 additions & 48 deletions python/distributed-ucxx/distributed_ucxx/ucxx.py
@@ -14,6 +14,7 @@
 import struct
 import weakref
 from collections.abc import Awaitable, Callable, Collection
+from threading import Lock
 from typing import TYPE_CHECKING, Any
 from unittest.mock import patch

@@ -49,6 +50,13 @@
 pre_existing_cuda_context = False
 cuda_context_created = False
 multi_buffer = None
+# Lock protecting access to _resources dict
+_resources_lock = Lock()
+# Mapping from UCXX context handles to sets of registered dask resource IDs
+# Used to track when there are no more users of the context, at which point
+# its progress task and notification thread can be shut down.
+# See _register_dask_resource and _deregister_dask_resource.
+_resources = dict()


 _warning_suffix = (
@@ -95,13 +103,13 @@ def make_register():
     count = itertools.count()

     def register() -> int:
-        """Register a Dask resource with the UCXX context.
+        """Register a Dask resource with the resource tracker.

-        Register a Dask resource with the UCXX context and keep track of it with the
-        use of a unique ID for the resource. The resource ID is later used to
-        deregister the resource from the UCXX context calling
-        `_deregister_dask_resource(resource_id)`, which stops the notifier thread
-        and progress tasks when no more UCXX resources are alive.
+        Generate a unique ID for the resource and register it with the resource
+        tracker. The resource ID is later used to deregister the resource from
+        the tracker calling `_deregister_dask_resource(resource_id)`, which
+        stops the notifier thread and progress tasks when no more UCXX resources
+        are alive.

         Returns
         -------
@@ -110,9 +118,13 @@ def register() -> int:
             `_deregister_dask_resource` during stop/destruction of the resource.
         """
         ctx = ucxx.core._get_ctx()
-        with ctx._dask_resources_lock:
+        handle = ctx.context.handle
+        with _resources_lock:
+            if handle not in _resources:
+                _resources[handle] = set()
+
             resource_id = next(count)
-            ctx._dask_resources.add(resource_id)
+            _resources[handle].add(resource_id)
             ctx.start_notifier_thread()
             ctx.continuous_ucx_progress()
             return resource_id
@@ -126,11 +138,11 @@ def register() -> int:


 def _deregister_dask_resource(resource_id):
-    """Deregister a Dask resource with the UCXX context.
+    """Deregister a Dask resource from the resource tracker.

-    Deregister a Dask resource from the UCXX context with given ID, and if no
-    resources remain after deregistration, stop the notifier thread and progress
-    tasks.
+    Deregister a Dask resource from the resource tracker with given ID, and if
+    no resources remain after deregistration, stop the notifier thread and
+    progress tasks.

     Parameters
     ----------
@@ -144,40 +156,22 @@ def _deregister_dask_resource(resource_id):
         return

     ctx = ucxx.core._get_ctx()
+    handle = ctx.context.handle

-    # Check if the attribute exists first, in tests the UCXX context may have
-    # been reset before some resources are deregistered.
-    if hasattr(ctx, "_dask_resources_lock"):
-        with ctx._dask_resources_lock:
-            try:
-                ctx._dask_resources.remove(resource_id)
-            except KeyError:
-                pass
-
-            # Stop notifier thread and progress tasks if no Dask resources using
-            # UCXX communicators are running anymore.
-            if len(ctx._dask_resources) == 0:
-                ctx.stop_notifier_thread()
-                ctx.progress_tasks.clear()
-
-
-def _allocate_dask_resources_tracker() -> None:
-    """Allocate Dask resources tracker.
-
-    Allocate a Dask resources tracker in the UCXX context. This is useful to
-    track Distributed communicators so that progress and notifier threads can
-    be cleanly stopped when no UCXX communicators are alive anymore.
-    """
-    ctx = ucxx.core._get_ctx()
-    if not hasattr(ctx, "_dask_resources"):
-        # TODO: Move the `Lock` to a file/module-level variable for true
-        # lock-safety. The approach implemented below could cause race
-        # conditions if this function is called simultaneously by multiple
-        # threads.
-        from threading import Lock
-
-        ctx._dask_resources = set()
-        ctx._dask_resources_lock = Lock()
+    with _resources_lock:
+        try:
+            _resources[handle].remove(resource_id)
+        except KeyError:
+            pass
+
+        # Stop notifier thread and progress tasks if no Dask resources using
+        # UCXX communicators are running anymore.
+        if handle in _resources and len(_resources[handle]) == 0:
+            ctx.stop_notifier_thread()
+            ctx.progress_tasks.clear()
+            del _resources[handle]


 def init_once():
@@ -187,11 +181,6 @@ def init_once():
     global multi_buffer

     if ucxx is not None:
-        # Ensure reallocation of Dask resources tracker if the UCXX context was
-        # reset since the previous `init_once()` call. This may happen in tests,
-        # where the `ucxx_loop` fixture will reset the context after each test.
-        _allocate_dask_resources_tracker()
-
         return

     # remove/process dask.ucx flags for valid ucx options
@@ -254,7 +243,6 @@ def init_once():
         # environment, so the user's external environment can safely
         # override things here.
         ucxx.init(options=ucx_config, env_takes_precedence=True)
-        _allocate_dask_resources_tracker()

     pool_size_str = dask.config.get("distributed.rmm.pool-size")

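
On the Python side, the per-context attributes ctx._dask_resources and ctx._dask_resources_lock are replaced by a module-level lock plus a dictionary keyed on the UCXX context handle, which removes the lazily allocated tracker (_allocate_dask_resources_tracker) and the hasattr checks the old code needed when, as its comment noted, "in tests the UCXX context may have been reset before some resources are deregistered". Below is a self-contained sketch of the same pattern, using stand-in names and a dummy context object rather than the real ucxx API, assuming only the behaviour visible in the diff above:

import itertools
from threading import Lock

_resources_lock = Lock()  # protects _resources
_resources = {}  # context handle -> set of live resource IDs
_ids = itertools.count()


def register(ctx) -> int:
    """Record one more user of ``ctx`` and return the new resource ID."""
    with _resources_lock:
        users = _resources.setdefault(ctx.handle, set())
        resource_id = next(_ids)
        users.add(resource_id)
    # Stand-in for starting the notifier thread / progress task.
    ctx.start_background_work()
    return resource_id


def deregister(ctx, resource_id) -> None:
    """Forget one user; stop background work when the last one is gone."""
    with _resources_lock:
        users = _resources.get(ctx.handle)
        if users is None:
            # The context entry is already gone (e.g. reset between tests).
            return
        users.discard(resource_id)
        if not users:
            del _resources[ctx.handle]
            # Stand-in for stopping the notifier thread / clearing progress tasks.
            ctx.stop_background_work()


class _DummyContext:
    handle = 1

    def start_background_work(self):
        print("start")

    def stop_background_work(self):
        print("stop")


ctx = _DummyContext()
rid = register(ctx)
deregister(ctx, rid)  # prints "start" then "stop"

Keying the bookkeeping on the context handle in module state, rather than storing it on the context object itself, is what lets a deregistration that arrives after a context reset fall through harmlessly; that is the role of the KeyError handling and the `handle in _resources` guard in the new code.
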
