diff --git a/tests/op/test_fit_memory_estimate.py b/tests/op/test_fit_memory_estimate.py index 3ba9e21..df026c8 100644 --- a/tests/op/test_fit_memory_estimate.py +++ b/tests/op/test_fit_memory_estimate.py @@ -1,13 +1,12 @@ import pytest from sklearn.linear_model import LinearRegression -from crossfit.backend.torch.hf.memory_curve_utils import fit_memory_estimate_curve - transformers = pytest.importorskip("transformers") torch = pytest.importorskip("torch") -rmm_torch_allocator = pytest.importorskip( - "rmm.allocators.torch", reason="rmm_torch_allocator is not available." -).rmm_torch_allocator +rmm_torch_allocator = pytest.importorskip("rmm.allocators.torch").rmm_torch_allocator +fit_memory_estimate_curve = pytest.importorskip( + "crossfit.backend.torch.hf.memory_curve_utils" +).fit_memory_estimate_curve MODEL_NAME = "microsoft/deberta-v3-base"