diff --git a/tests/conftest.py b/tests/conftest.py index 2598ced6408..0ca0ad87489 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -52,10 +52,10 @@ def assert_eq(a, b, *args, **kwargs): import pytest from asvdb import ASVDb, BenchmarkInfo, utils -from dask.distributed import Client, LocalCluster from numba import cuda import nvtabular +from merlin.core.utils import Distributed from merlin.dag.node import iter_nodes REPO_ROOT = Path(__file__).parent.parent @@ -97,8 +97,9 @@ def assert_eq(a, b, *args, **kwargs): @pytest.fixture(scope="module") def client(): - cluster = LocalCluster(n_workers=2) - client = Client(cluster) + distributed = Distributed(n_workers=2) + cluster = distributed.cluster + client = distributed.client yield client client.close() cluster.close() diff --git a/tests/unit/framework_utils/test_tf_layers.py b/tests/unit/framework_utils/test_tf_layers.py index 106be0fa457..38e2778cab7 100644 --- a/tests/unit/framework_utils/test_tf_layers.py +++ b/tests/unit/framework_utils/test_tf_layers.py @@ -318,4 +318,4 @@ def test_multihot_empty_rows(): ) y_hat = model(x).numpy() - np.testing.assert_allclose(y_hat, multi_hot_embedding_rows, rtol=1e-06) + np.testing.assert_allclose(y_hat, multi_hot_embedding_rows, rtol=1e-05)