Commit

Address Peter's Review
VibhuJawa committed Oct 9, 2024
1 parent 94982f6 commit 33871ac
Showing 4 changed files with 75 additions and 71 deletions.
3 changes: 1 addition & 2 deletions dask_cuda/cli.py
@@ -172,8 +172,7 @@ def cuda():
     show_default=True,
     help="""
     Set RMM as the allocator for external libraries. Provide a comma-separated
-    list of libraries to set, e.g., "torch,cupy".
-    Supported options are: torch, cupy.""",
+    list of libraries to set, e.g., "torch,cupy".""",
 )
 @click.option(
     "--rmm-release-threshold",
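For context, a minimal standalone sketch of how a comma-separated click option like the one documented above is typically consumed. The option name --rmm-allocator-external-lib-list and the manual splitting are assumptions here, since only the help text appears in this hunk:

import click


@click.command()
@click.option(
    "--rmm-allocator-external-lib-list",
    default=None,
    show_default=True,
    help='Comma-separated list of libraries, e.g. "torch,cupy".',
)
def main(rmm_allocator_external_lib_list):
    # Split the raw comma-separated string into a clean list of names,
    # mirroring what a CommaSeparatedChoice-style option type would yield.
    libs = (
        [v.strip() for v in rmm_allocator_external_lib_list.split(",")]
        if rmm_allocator_external_lib_list
        else []
    )
    click.echo(f"Libraries to configure with RMM: {libs}")


if __name__ == "__main__":
    main()
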
14 changes: 11 additions & 3 deletions dask_cuda/local_cuda_cluster.py
@@ -145,7 +145,7 @@ class LocalCUDACluster(LocalCluster):
         result in an exception.
     rmm_allocator_external_lib_list: list or None, default None
         List of external libraries for which to set RMM as the allocator.
-        Supported options are: ``["torch", "cupy"]``. If None, no external
+        Supported options are: ``["torch", "cupy"]``. If ``None``, no external
         libraries will use RMM as their allocator.
     rmm_release_threshold: int, str or None, default None
         When ``rmm.async is True`` and the pool size grows beyond this value, unused
@@ -271,8 +271,16 @@ def __init__(
         if n_workers < 1:
             raise ValueError("Number of workers cannot be less than 1.")
 
-        if isinstance(rmm_allocator_external_lib_list, str):
-            rmm_allocator_external_lib_list = []
+        if rmm_allocator_external_lib_list is not None and not isinstance(
+            rmm_allocator_external_lib_list, list
+        ):
+            raise ValueError(
+                "rmm_allocator_external_lib_list must be a list of strings. "
+                "Valid examples: ['torch'], ['cupy'], or ['torch', 'cupy']. "
+                f"Received: {type(rmm_allocator_external_lib_list)} "
+                f"with value: {rmm_allocator_external_lib_list}"
+            )
 
         # Set nthreads=1 when parsing mem_limit since it only depends on n_workers
         logger = logging.getLogger(__name__)
         self.memory_limit = parse_memory_limit(
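The change above replaces silent handling of string input (previously coerced to an empty list) with an explicit error. A standalone sketch of the new check's behavior, assuming it mirrors the __init__ logic exactly; the helper name below is local to this example:

from typing import List, Optional


def validate_lib_list(value: Optional[List[str]]) -> Optional[List[str]]:
    # None is allowed (no external libraries will use RMM); anything else
    # must already be a list such as ["torch", "cupy"].
    if value is not None and not isinstance(value, list):
        raise ValueError(
            "rmm_allocator_external_lib_list must be a list of strings. "
            "Valid examples: ['torch'], ['cupy'], or ['torch', 'cupy']. "
            f"Received: {type(value)} with value: {value}"
        )
    return value


validate_lib_list(["torch", "cupy"])  # returns the list unchanged
validate_lib_list(None)               # returns None
# validate_lib_list("torch,cupy")     # raises ValueError under the new check
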
67 changes: 62 additions & 5 deletions dask_cuda/plugins.py
@@ -1,13 +1,10 @@
 import importlib
 import os
+from typing import Callable, Dict
 
 from distributed import WorkerPlugin
 
-from .utils import (
-    enable_rmm_memory_for_library,
-    get_rmm_log_file_name,
-    parse_device_memory_limit,
-)
+from .utils import get_rmm_log_file_name, parse_device_memory_limit
 
 
 class CPUAffinity(WorkerPlugin):
@@ -134,6 +131,66 @@ def setup(self, worker=None):
             enable_rmm_memory_for_library(lib)
 
 
+def enable_rmm_memory_for_library(lib_name: str) -> None:
+    """Enable RMM memory pool support for a specified third-party library.
+
+    This function allows the given library to utilize RMM's memory pool if it supports
+    integration with RMM. The library name is passed as a string argument, and if the
+    library is compatible, its memory allocator will be configured to use RMM.
+
+    Parameters
+    ----------
+    lib_name : str
+        The name of the third-party library to enable RMM memory pool support for.
+        Supported libraries are "cupy" and "torch".
+
+    Raises
+    ------
+    ValueError
+        If the library name is not supported or does not have RMM integration.
+    ImportError
+        If the required library is not installed.
+    """
+
+    # Mapping of supported libraries to their respective setup functions
+    setup_functions: Dict[str, Callable[[], None]] = {
+        "torch": _setup_rmm_for_torch,
+        "cupy": _setup_rmm_for_cupy,
+    }
+
+    if lib_name not in setup_functions:
+        supported_libs = ", ".join(setup_functions.keys())
+        raise ValueError(
+            f"The library '{lib_name}' is not supported for RMM integration. "
+            f"Supported libraries are: {supported_libs}."
+        )
+
+    # Call the setup function for the specified library
+    setup_functions[lib_name]()
+
+
+def _setup_rmm_for_torch() -> None:
+    try:
+        import torch
+    except ImportError as e:
+        raise ImportError("PyTorch is not installed.") from e
+
+    from rmm.allocators.torch import rmm_torch_allocator
+
+    torch.cuda.memory.change_current_allocator(rmm_torch_allocator)
+
+
+def _setup_rmm_for_cupy() -> None:
+    try:
+        import cupy
+    except ImportError as e:
+        raise ImportError("CuPy is not installed.") from e
+
+    from rmm.allocators.cupy import rmm_cupy_allocator
+
+    cupy.cuda.set_allocator(rmm_cupy_allocator)
 
 
 class PreImport(WorkerPlugin):
     def __init__(self, libraries):
         if libraries is None:
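The function moved here from dask_cuda/utils.py dispatches through a dict of per-library setup functions, so supporting another library only requires a new entry. A usage sketch, not part of the diff; it assumes a CUDA-capable environment with RMM, CuPy, and PyTorch installed:

from dask_cuda.plugins import enable_rmm_memory_for_library

enable_rmm_memory_for_library("cupy")   # CuPy allocations now go through RMM
enable_rmm_memory_for_library("torch")  # PyTorch CUDA allocations as well

try:
    enable_rmm_memory_for_library("numpy")  # not in the dispatch table
except ValueError as err:
    print(err)  # names the supported libraries: torch, cupy
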
62 changes: 1 addition & 61 deletions dask_cuda/utils.py
@@ -7,7 +7,7 @@
 from contextlib import suppress
 from functools import singledispatch
 from multiprocessing import cpu_count
-from typing import Callable, Dict, Optional
+from typing import Optional
 
 import click
 import numpy as np
@@ -767,66 +767,6 @@ def get_rmm_memory_resource_stack(mr) -> list:
     return None
 
 
-def enable_rmm_memory_for_library(lib_name: str) -> None:
-    """
-    Enable RMM memory pool support for a specified third-party library.
-
-    This function allows the given library to utilize RMM's memory pool if it supports
-    integration with RMM. The library name is passed as a string argument, and if the
-    library is compatible, its memory allocator will be configured to use RMM.
-
-    Parameters
-    ----------
-    lib_name : str
-        The name of the third-party library to enable RMM memory pool support for.
-
-    Raises
-    ------
-    ValueError
-        If the library name is not supported or does not have RMM integration.
-    ImportError
-        If the required library is not installed.
-    """
-
-    # Mapping of supported libraries to their respective setup functions
-    setup_functions: Dict[str, Callable[[], None]] = {
-        "torch": _setup_rmm_for_torch,
-        "cupy": _setup_rmm_for_cupy,
-    }
-
-    if lib_name not in setup_functions:
-        supported_libs = ", ".join(setup_functions.keys())
-        raise ValueError(
-            f"The library '{lib_name}' is not supported for RMM integration. "
-            f"Supported libraries are: {supported_libs}."
-        )
-
-    # Call the setup function for the specified library
-    setup_functions[lib_name]()
-
-
-def _setup_rmm_for_torch() -> None:
-    try:
-        import torch
-    except ImportError as e:
-        raise ImportError("PyTorch is not installed.") from e
-
-    from rmm.allocators.torch import rmm_torch_allocator
-
-    torch.cuda.memory.change_current_allocator(rmm_torch_allocator)
-
-
-def _setup_rmm_for_cupy() -> None:
-    try:
-        import cupy
-    except ImportError as e:
-        raise ImportError("CuPy is not installed.") from e
-
-    from rmm.allocators.cupy import rmm_cupy_allocator
-
-    cupy.cuda.set_allocator(rmm_cupy_allocator)
 
 
 class CommaSeparatedChoice(click.Choice):
     def convert(self, value, param, ctx):
         values = [v.strip() for v in value.split(",")]
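Only the first line of CommaSeparatedChoice.convert is visible in this hunk. A hypothetical completion, illustrating the presumed behavior of splitting on commas, stripping whitespace, and validating each token as an ordinary click.Choice value:

import click


class CommaSeparatedChoice(click.Choice):
    def convert(self, value, param, ctx):
        # Split "torch,cupy" into ["torch", "cupy"], then validate each
        # token against the allowed choices via the parent class.
        values = [v.strip() for v in value.split(",")]
        return [super(CommaSeparatedChoice, self).convert(v, param, ctx) for v in values]


choices = CommaSeparatedChoice(["torch", "cupy"])
print(choices.convert("torch, cupy", None, None))  # ['torch', 'cupy']
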
