From c2d970b4af7b8ce1a2f2a2172f14e9b24fa60006 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti <102977828+flferretti@users.noreply.github.com> Date: Tue, 27 Feb 2024 10:39:10 +0100 Subject: [PATCH 1/5] Initialize JAX configuration as per the release --- src/jaxsim/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/jaxsim/__init__.py b/src/jaxsim/__init__.py index 82da0e3ac..c06d2c7cc 100644 --- a/src/jaxsim/__init__.py +++ b/src/jaxsim/__init__.py @@ -6,12 +6,12 @@ def _jnp_options() -> None: import os - from jax.config import config + import jax # Enable by default if not ("JAX_ENABLE_X64" in os.environ and os.environ["JAX_ENABLE_X64"] == "0"): logging.info("Enabling JAX to use 64bit precision") - config.update("jax_enable_x64", True) + jax.config.update("jax_enable_x64", True) import jax.numpy as jnp import numpy as np From 68ab8a2b264ed403b2fbbc0b87f38df1e89af368 Mon Sep 17 00:00:00 2001 From: Silvio Traversaro Date: Tue, 27 Feb 2024 14:15:10 +0100 Subject: [PATCH 2/5] Remove --forked from pytest as workaround for jax 0.4.25 --- pyproject.toml | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 3a74e3a49..3ec523dc3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,7 +19,10 @@ multi_line_output = 3 [tool.pytest.ini_options] minversion = "6.0" -addopts = "-rsxX -v --strict-markers --forked" +# --forked used to be passed here, but it was removed +# as workaround for compatibility with jax 0.4.25, +# see https://github.com/ami-iit/jaxsim/pull/92#issuecomment-1966290170 +addopts = "-rsxX -v --strict-markers" testpaths = [ "tests", ] From 48c8396f8391016bb45edb4548d04c34dce55c15 Mon Sep 17 00:00:00 2001 From: Silvio Traversaro Date: Tue, 27 Feb 2024 16:13:01 +0100 Subject: [PATCH 3/5] Constrain jax to be lower of 0.4.25 until we switch to functional --- setup.cfg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index 32bc02968..2204355bc 100644 --- a/setup.cfg +++ b/setup.cfg @@ -53,7 +53,7 @@ package_dir = python_requires = >=3.11 install_requires = coloredlogs - jax >= 0.4.13 + jax >= 0.4.13,<0.4.25 jaxlib jaxlie >= 1.3.0 jax_dataclasses >= 1.4.0 From 032dee589c72f222eae2c4843e647f5a57593d83 Mon Sep 17 00:00:00 2001 From: Silvio Traversaro Date: Tue, 27 Feb 2024 16:17:27 +0100 Subject: [PATCH 4/5] Add constraint also to jaxlib --- setup.cfg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index 2204355bc..3dc535115 100644 --- a/setup.cfg +++ b/setup.cfg @@ -54,7 +54,7 @@ python_requires = >=3.11 install_requires = coloredlogs jax >= 0.4.13,<0.4.25 - jaxlib + jaxlib >= 0.4.13,<0.4.25 jaxlie >= 1.3.0 jax_dataclasses >= 1.4.0 pptree From 10d467779ba085fbfc35aeee2e4fcd7b0c8c564c Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti <102977828+flferretti@users.noreply.github.com> Date: Wed, 28 Feb 2024 16:51:53 +0100 Subject: [PATCH 5/5] Revert "Remove --forked from pytest as workaround for jax 0.4.25" --- pyproject.toml | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 3ec523dc3..3a74e3a49 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,10 +19,7 @@ multi_line_output = 3 [tool.pytest.ini_options] minversion = "6.0" -# --forked used to be passed here, but it was removed -# as workaround for compatibility with jax 0.4.25, -# see https://github.com/ami-iit/jaxsim/pull/92#issuecomment-1966290170 -addopts = "-rsxX -v --strict-markers" +addopts = "-rsxX -v --strict-markers --forked" testpaths = [ "tests", ]