From 33871acd1ace384e1237f2caaf4a041a2c1b7cbf Mon Sep 17 00:00:00 2001 From: Vibhu Jawa Date: Wed, 9 Oct 2024 16:35:51 -0700 Subject: [PATCH] Address Peters Review --- dask_cuda/cli.py | 3 +- dask_cuda/local_cuda_cluster.py | 14 +++++-- dask_cuda/plugins.py | 67 ++++++++++++++++++++++++++++++--- dask_cuda/utils.py | 62 +----------------------------- 4 files changed, 75 insertions(+), 71 deletions(-) diff --git a/dask_cuda/cli.py b/dask_cuda/cli.py index ea90b96e..8101f020 100644 --- a/dask_cuda/cli.py +++ b/dask_cuda/cli.py @@ -172,8 +172,7 @@ def cuda(): show_default=True, help=""" Set RMM as the allocator for external libraries. Provide a comma-separated - list of libraries to set, e.g., "torch,cupy". - Supported options are: torch, cupy.""", + list of libraries to set, e.g., "torch,cupy".""", ) @click.option( "--rmm-release-threshold", diff --git a/dask_cuda/local_cuda_cluster.py b/dask_cuda/local_cuda_cluster.py index 959f243e..d68547c1 100644 --- a/dask_cuda/local_cuda_cluster.py +++ b/dask_cuda/local_cuda_cluster.py @@ -145,7 +145,7 @@ class LocalCUDACluster(LocalCluster): result in an exception. rmm_allocator_external_lib_list: list or None, default None List of external libraries for which to set RMM as the allocator. - Supported options are: ``["torch", "cupy"]``. If None, no external + Supported options are: ``["torch", "cupy"]``. If ``None``, no external libraries will use RMM as their allocator. rmm_release_threshold: int, str or None, default None When ``rmm.async is True`` and the pool size grows beyond this value, unused @@ -271,8 +271,16 @@ def __init__( if n_workers < 1: raise ValueError("Number of workers cannot be less than 1.") - if isinstance(rmm_allocator_external_lib_list, str): - rmm_allocator_external_lib_list = [] + if rmm_allocator_external_lib_list is not None and not isinstance( + rmm_allocator_external_lib_list, list + ): + raise ValueError( + "rmm_allocator_external_lib_list must be a list of strings. " + "Valid examples: ['torch'], ['cupy'], or ['torch', 'cupy']. " + f"Received: {type(rmm_allocator_external_lib_list)} " + f"with value: {rmm_allocator_external_lib_list}" + ) + # Set nthreads=1 when parsing mem_limit since it only depends on n_workers logger = logging.getLogger(__name__) self.memory_limit = parse_memory_limit( diff --git a/dask_cuda/plugins.py b/dask_cuda/plugins.py index c2844278..cd1928af 100644 --- a/dask_cuda/plugins.py +++ b/dask_cuda/plugins.py @@ -1,13 +1,10 @@ import importlib import os +from typing import Callable, Dict from distributed import WorkerPlugin -from .utils import ( - enable_rmm_memory_for_library, - get_rmm_log_file_name, - parse_device_memory_limit, -) +from .utils import get_rmm_log_file_name, parse_device_memory_limit class CPUAffinity(WorkerPlugin): @@ -134,6 +131,66 @@ def setup(self, worker=None): enable_rmm_memory_for_library(lib) +def enable_rmm_memory_for_library(lib_name: str) -> None: + """Enable RMM memory pool support for a specified third-party library. + + This function allows the given library to utilize RMM's memory pool if it supports + integration with RMM. The library name is passed as a string argument, and if the + library is compatible, its memory allocator will be configured to use RMM. + + Parameters + ---------- + lib_name : str + The name of the third-party library to enable RMM memory pool support for. + Supported libraries are "cupy" and "torch". + + Raises + ------ + ValueError + If the library name is not supported or does not have RMM integration. + ImportError + If the required library is not installed. + """ + + # Mapping of supported libraries to their respective setup functions + setup_functions: Dict[str, Callable[[], None]] = { + "torch": _setup_rmm_for_torch, + "cupy": _setup_rmm_for_cupy, + } + + if lib_name not in setup_functions: + supported_libs = ", ".join(setup_functions.keys()) + raise ValueError( + f"The library '{lib_name}' is not supported for RMM integration. " + f"Supported libraries are: {supported_libs}." + ) + + # Call the setup function for the specified library + setup_functions[lib_name]() + + +def _setup_rmm_for_torch() -> None: + try: + import torch + except ImportError as e: + raise ImportError("PyTorch is not installed.") from e + + from rmm.allocators.torch import rmm_torch_allocator + + torch.cuda.memory.change_current_allocator(rmm_torch_allocator) + + +def _setup_rmm_for_cupy() -> None: + try: + import cupy + except ImportError as e: + raise ImportError("CuPy is not installed.") from e + + from rmm.allocators.cupy import rmm_cupy_allocator + + cupy.cuda.set_allocator(rmm_cupy_allocator) + + class PreImport(WorkerPlugin): def __init__(self, libraries): if libraries is None: diff --git a/dask_cuda/utils.py b/dask_cuda/utils.py index e7d7cdbb..74596fe2 100644 --- a/dask_cuda/utils.py +++ b/dask_cuda/utils.py @@ -7,7 +7,7 @@ from contextlib import suppress from functools import singledispatch from multiprocessing import cpu_count -from typing import Callable, Dict, Optional +from typing import Optional import click import numpy as np @@ -767,66 +767,6 @@ def get_rmm_memory_resource_stack(mr) -> list: return None -def enable_rmm_memory_for_library(lib_name: str) -> None: - """ - Enable RMM memory pool support for a specified third-party library. - - This function allows the given library to utilize RMM's memory pool if it supports - integration with RMM. The library name is passed as a string argument, and if the - library is compatible, its memory allocator will be configured to use RMM. - - Parameters - ---------- - lib_name : str - The name of the third-party library to enable RMM memory pool support for. - - Raises - ------ - ValueError - If the library name is not supported or does not have RMM integration. - ImportError - If the required library is not installed. - """ - - # Mapping of supported libraries to their respective setup functions - setup_functions: Dict[str, Callable[[], None]] = { - "torch": _setup_rmm_for_torch, - "cupy": _setup_rmm_for_cupy, - } - - if lib_name not in setup_functions: - supported_libs = ", ".join(setup_functions.keys()) - raise ValueError( - f"The library '{lib_name}' is not supported for RMM integration. " - f"Supported libraries are: {supported_libs}." - ) - - # Call the setup function for the specified library - setup_functions[lib_name]() - - -def _setup_rmm_for_torch() -> None: - try: - import torch - except ImportError as e: - raise ImportError("PyTorch is not installed.") from e - - from rmm.allocators.torch import rmm_torch_allocator - - torch.cuda.memory.change_current_allocator(rmm_torch_allocator) - - -def _setup_rmm_for_cupy() -> None: - try: - import cupy - except ImportError as e: - raise ImportError("CuPy is not installed.") from e - - from rmm.allocators.cupy import rmm_cupy_allocator - - cupy.cuda.set_allocator(rmm_cupy_allocator) - - class CommaSeparatedChoice(click.Choice): def convert(self, value, param, ctx): values = [v.strip() for v in value.split(",")]