diff --git a/src/jaxsim/__init__.py b/src/jaxsim/__init__.py index 82da0e3ac..c06d2c7cc 100644 --- a/src/jaxsim/__init__.py +++ b/src/jaxsim/__init__.py @@ -6,12 +6,12 @@ def _jnp_options() -> None: import os - from jax.config import config + import jax # Enable by default if not ("JAX_ENABLE_X64" in os.environ and os.environ["JAX_ENABLE_X64"] == "0"): logging.info("Enabling JAX to use 64bit precision") - config.update("jax_enable_x64", True) + jax.config.update("jax_enable_x64", True) import jax.numpy as jnp import numpy as np