diff --git a/examples/03-Running-on-multiple-GPUs-or-on-CPU.ipynb b/examples/03-Running-on-multiple-GPUs-or-on-CPU.ipynb index aba2647567d..3c90574ff5f 100644 --- a/examples/03-Running-on-multiple-GPUs-or-on-CPU.ipynb +++ b/examples/03-Running-on-multiple-GPUs-or-on-CPU.ipynb @@ -27,6 +27,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "77464844", "metadata": {}, @@ -53,6 +54,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "1c5598ae", "metadata": {}, @@ -92,6 +94,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "63ac0cf2", "metadata": {}, @@ -100,6 +103,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "4def0005", "metadata": {}, @@ -123,6 +127,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "d7c3f9ea", "metadata": {}, @@ -148,6 +153,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "728c3009", "metadata": {}, @@ -176,11 +182,15 @@ "\n", "# Deploy a Single-Machine Multi-GPU Cluster\n", "protocol = \"tcp\" # \"tcp\" or \"ucx\"\n", + "\n", "if numba.cuda.is_available():\n", " NUM_GPUS = list(range(len(numba.cuda.gpus)))\n", "else:\n", " NUM_GPUS = []\n", - "visible_devices = \",\".join([str(n) for n in NUM_GPUS]) # Delect devices to place workers\n", + "try:\n", + " visible_devices = os.environ[\"CUDA_VISIBLE_DEVICES\"]\n", + "except KeyError:\n", + " visible_devices = \",\".join([str(n) for n in NUM_GPUS]) # Delect devices to place workers\n", "device_limit_frac = 0.7 # Spill GPU-Worker memory to host at this limit.\n", "device_pool_frac = 0.8\n", "part_mem_frac = 0.15\n", @@ -206,6 +216,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "d14dc098", "metadata": {}, @@ -242,6 +253,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "0576affe", "metadata": {}, @@ -589,6 +601,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "94ef0024", "metadata": {}, @@ -599,6 +612,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "768fc24e", "metadata": {}, @@ -622,6 +636,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "61785127", "metadata": {}, @@ -678,6 +693,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "01ea40bb", "metadata": {}, @@ -686,6 +702,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "987f3274", "metadata": {}, @@ -714,6 +731,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "b06c962e", "metadata": {}, @@ -745,6 +763,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "d28ae761", "metadata": {}, @@ -755,6 +774,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "4e07864d", "metadata": {}, @@ -763,6 +783,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "8f971a22", "metadata": {}, diff --git a/tests/conftest.py b/tests/conftest.py index 2598ced6408..0ca0ad87489 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -52,10 +52,10 @@ def assert_eq(a, b, *args, **kwargs): import pytest from asvdb import ASVDb, BenchmarkInfo, utils -from dask.distributed import Client, LocalCluster from numba import cuda import nvtabular +from merlin.core.utils import Distributed from merlin.dag.node import iter_nodes REPO_ROOT = Path(__file__).parent.parent @@ -97,8 +97,9 @@ def assert_eq(a, b, *args, **kwargs): @pytest.fixture(scope="module") def client(): - cluster = LocalCluster(n_workers=2) - client = Client(cluster) + distributed = Distributed(n_workers=2) + cluster = distributed.cluster + client = distributed.client yield client client.close() cluster.close() diff --git a/tests/unit/framework_utils/test_tf_layers.py b/tests/unit/framework_utils/test_tf_layers.py index 106be0fa457..38e2778cab7 100644 --- a/tests/unit/framework_utils/test_tf_layers.py +++ b/tests/unit/framework_utils/test_tf_layers.py @@ -318,4 +318,4 @@ def test_multihot_empty_rows(): ) y_hat = model(x).numpy() - np.testing.assert_allclose(y_hat, multi_hot_embedding_rows, rtol=1e-06) + np.testing.assert_allclose(y_hat, multi_hot_embedding_rows, rtol=1e-05)