From 5bc01039ef989f985a21d2e595ad8ad0c3eafcc1 Mon Sep 17 00:00:00 2001 From: Oliver Holworthy Date: Mon, 5 Jun 2023 16:27:47 +0100 Subject: [PATCH] Use Distributed helper for client fixture in conftest.py (#1830) * Use Distributed helper for client fixture * reduce rtol in test_multihot_empty_rows --- tests/conftest.py | 7 ++++--- tests/unit/framework_utils/test_tf_layers.py | 2 +- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 2598ced6408..0ca0ad87489 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -52,10 +52,10 @@ def assert_eq(a, b, *args, **kwargs): import pytest from asvdb import ASVDb, BenchmarkInfo, utils -from dask.distributed import Client, LocalCluster from numba import cuda import nvtabular +from merlin.core.utils import Distributed from merlin.dag.node import iter_nodes REPO_ROOT = Path(__file__).parent.parent @@ -97,8 +97,9 @@ def assert_eq(a, b, *args, **kwargs): @pytest.fixture(scope="module") def client(): - cluster = LocalCluster(n_workers=2) - client = Client(cluster) + distributed = Distributed(n_workers=2) + cluster = distributed.cluster + client = distributed.client yield client client.close() cluster.close() diff --git a/tests/unit/framework_utils/test_tf_layers.py b/tests/unit/framework_utils/test_tf_layers.py index 106be0fa457..38e2778cab7 100644 --- a/tests/unit/framework_utils/test_tf_layers.py +++ b/tests/unit/framework_utils/test_tf_layers.py @@ -318,4 +318,4 @@ def test_multihot_empty_rows(): ) y_hat = model(x).numpy() - np.testing.assert_allclose(y_hat, multi_hot_embedding_rows, rtol=1e-06) + np.testing.assert_allclose(y_hat, multi_hot_embedding_rows, rtol=1e-05)