diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 77995cb5a74..bcb11aa1f79 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -31,7 +31,7 @@ repos: exclude: ^(packages/grid/ansible/) - id: name-tests-test always_run: true - exclude: ^(packages/grid/backend/grid/tests/utils/)|^(.*fixtures.py) + exclude: ^(packages/grid/backend/grid/tests/utils/)|^(.*fixtures.py)|^packages/syft/tests/.*/utils.py - id: requirements-txt-fixer always_run: true - id: mixed-line-ending diff --git a/packages/syft/tests/syft/action_test.py b/packages/syft/tests/syft/action_test.py index a0a42f6accb..7cdc5d73232 100644 --- a/packages/syft/tests/syft/action_test.py +++ b/packages/syft/tests/syft/action_test.py @@ -8,6 +8,9 @@ from syft.service.response import SyftError from syft.types.uid import LineageID +# relative +from .utils import currently_fail_on_python_3_12 + def test_actionobject_method(worker): root_domain_client = worker.root_client @@ -20,6 +23,7 @@ def test_actionobject_method(worker): assert res[0] == "A" +@currently_fail_on_python_3_12(raises=AttributeError) def test_lib_function_action(worker): root_domain_client = worker.root_client numpy_client = root_domain_client.api.lib.numpy diff --git a/packages/syft/tests/syft/eager_test.py b/packages/syft/tests/syft/eager_test.py index 68ef65404cd..7f34e80430d 100644 --- a/packages/syft/tests/syft/eager_test.py +++ b/packages/syft/tests/syft/eager_test.py @@ -1,16 +1,13 @@ -# stdlib -import sys - # third party import numpy as np -import pytest # syft absolute from syft.service.action.action_object import ActionObject from syft.service.action.plan import planify from syft.types.twin_object import TwinObject -PYTHON_AT_LEAST_3_12 = sys.version_info >= (3, 12) +# relative +from .utils import currently_fail_on_python_3_12 def test_eager_permissions(worker, guest_client): @@ -76,11 +73,7 @@ def my_plan(x=np.array([[2, 2, 2], [2, 2, 2]])): # noqa: B008 assert res_ptr.get_from(guest_client) == 729 -@pytest.mark.xfail( - PYTHON_AT_LEAST_3_12, - raises=AttributeError, - reason="Does not work yet on Python>=3.12 and numpy>=1.26", -) +@currently_fail_on_python_3_12(raises=AttributeError) def test_plan_with_function_call(worker, guest_client): root_domain_client = worker.root_client guest_client = worker.guest_client diff --git a/packages/syft/tests/syft/serde/__init__.py b/packages/syft/tests/syft/serde/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/packages/syft/tests/syft/serde/numpy_functions_test.py b/packages/syft/tests/syft/serde/numpy_functions_test.py index c6a503b9d48..4b759322e23 100644 --- a/packages/syft/tests/syft/serde/numpy_functions_test.py +++ b/packages/syft/tests/syft/serde/numpy_functions_test.py @@ -1,6 +1,3 @@ -# stdlib -import sys - # third party import numpy as np import pytest @@ -9,12 +6,14 @@ from syft import ActionObject from syft.service.response import SyftAttributeError +# relative +from ..utils import PYTHON_AT_LEAST_3_12 +from ..utils import currently_fail_on_python_3_12 + PYTHON_ARRAY = [0, 1, 1, 2, 2, 3] NP_ARRAY = np.array([0, 1, 1, 5, 5, 3]) NP_2dARRAY = np.array([[3, 4, 5, 2], [6, 7, 2, 6]]) -PYTHON_AT_LEAST_3_12 = sys.version_info >= (3, 12) - NOT_WORK_YET_ON_NUMPY_1_26_PYTHON_3_12: list[tuple[str, str]] = [ ("linspace", "10,10,10"), ("logspace", "0,2"), @@ -85,10 +84,7 @@ pytest.param( func, func_arguments, - marks=pytest.mark.xfail( - PYTHON_AT_LEAST_3_12, - reason="Does not work yet on Python>=3.12 and numpy>=1.26", - ), + marks=currently_fail_on_python_3_12(), ) for func, func_arguments in NOT_WORK_YET_ON_NUMPY_1_26_PYTHON_3_12 ], diff --git a/packages/syft/tests/syft/utils.py b/packages/syft/tests/syft/utils.py new file mode 100644 index 00000000000..668c551efbc --- /dev/null +++ b/packages/syft/tests/syft/utils.py @@ -0,0 +1,15 @@ +# stdlib +from functools import partial +import sys + +# third party +import pytest + +PYTHON_AT_LEAST_3_12 = sys.version_info >= (3, 12) +FAIL_ON_PYTHON_3_12_REASON = "Does not work yet on Python>=3.12 and numpy>=1.26" + +currently_fail_on_python_3_12 = partial( + pytest.mark.xfail, + PYTHON_AT_LEAST_3_12, + reason=FAIL_ON_PYTHON_3_12_REASON, +)