From acedffddcdca92ff20eca63372c5f10c1a05511c Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Tue, 11 Jun 2024 10:38:12 +0200 Subject: [PATCH 01/26] Update notation for support body array --- src/jaxsim/api/link.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/jaxsim/api/link.py b/src/jaxsim/api/link.py index 27c352db9..ba0e0ee91 100644 --- a/src/jaxsim/api/link.py +++ b/src/jaxsim/api/link.py @@ -241,8 +241,8 @@ def jacobian( ) # Compute the actual doubly-left free-floating jacobian of the link. - κ = model.kin_dyn_parameters.support_body_array_bool[link_index] - B_J_WL_B = jnp.hstack([jnp.ones(5), κ]) * B_J_full_WX_B + κb = model.kin_dyn_parameters.support_body_array_bool[link_index] + B_J_WL_B = jnp.hstack([jnp.ones(5), κb]) * B_J_full_WX_B # Adjust the input representation such that `J_WL_I @ I_ν`. match data.velocity_representation: From 8952d908d9c65c5b16f00cde6b9938a135181a95 Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Tue, 11 Jun 2024 10:39:51 +0200 Subject: [PATCH 02/26] Minor fixes of jaxsim.api.model --- src/jaxsim/api/model.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index 303339e55..a65575d58 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -33,6 +33,7 @@ class JaxSimModel(JaxsimDataclass): terrain: Static[jaxsim.terrain.Terrain] = dataclasses.field( default=jaxsim.terrain.FlatTerrain(), repr=False, compare=False, hash=False ) + kin_dyn_parameters: js.kin_dyn_parameters.KynDynParameters | None = ( dataclasses.field(default=None, repr=False, compare=False, hash=False) ) @@ -302,7 +303,7 @@ def reduce( locked_joint_positions: A dictionary containing the positions of the joints to be considered in the reduction process. The removed joints in the reduced model - will have their position locked to their value in this dictionary. + will have their position locked to their value of this dictionary. If a joint is not part of the dictionary, its position is set to zero. """ @@ -1483,12 +1484,7 @@ def link_bias_accelerations( # ================================================ # Compute the base transform. - W_H_B = jaxlie.SE3.from_rotation_and_translation( - rotation=jaxlie.SO3.from_quaternion_xyzw( - xyzw=jaxsim.math.Quaternion.to_xyzw(wxyz=data.base_orientation()) - ), - translation=data.base_position(), - ).as_matrix() + W_H_B = data.base_transform() def other_representation_to_inertial( C_v̇_WB: jtp.Vector, C_v_WB: jtp.Vector, W_H_C: jtp.Matrix, W_v_WC: jtp.Vector @@ -1529,9 +1525,12 @@ def other_representation_to_inertial( W_H_C = W_H_BW with data.switch_velocity_representation(VelRepr.Mixed): W_ṗ_B = data.base_velocity()[0:3] - W_v_WC = W_v_W_BW = jnp.zeros(6).at[0:3].set(W_ṗ_B) + BW_v_W_BW = jnp.zeros(6).at[0:3].set(W_ṗ_B) + W_X_BW = jaxsim.math.Adjoint.from_transform(transform=W_H_BW) + W_v_WC = W_v_W_BW = W_X_BW @ BW_v_W_BW with data.switch_velocity_representation(VelRepr.Mixed): C_v_WB = BW_v_WB = data.base_velocity() + case _: raise ValueError(data.velocity_representation) From 439d6a9ff30941eab1abf73bde7df23867c977a7 Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Tue, 11 Jun 2024 10:47:33 +0200 Subject: [PATCH 03/26] Expose the Baumgarte regularization coefficient --- src/jaxsim/api/ode.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/src/jaxsim/api/ode.py b/src/jaxsim/api/ode.py index 426934f16..5d6463c85 100644 --- a/src/jaxsim/api/ode.py +++ b/src/jaxsim/api/ode.py @@ -223,7 +223,9 @@ def system_velocity_dynamics( @jax.jit def system_position_dynamics( - model: js.model.JaxSimModel, data: js.data.JaxSimModelData + model: js.model.JaxSimModel, + data: js.data.JaxSimModelData, + baumgarte_quaternion_regularization: jtp.FloatLike = 1.0, ) -> tuple[jtp.Vector, jtp.Vector, jtp.Vector]: """ Compute the dynamics of the system position. @@ -231,6 +233,8 @@ def system_position_dynamics( Args: model: The model to consider. data: The data of the considered model. + baumgarte_quaternion_regularization: + The Baumgarte regularization coefficient for adjusting the quaternion norm. Returns: A tuple containing the derivative of the base position, the derivative of the @@ -250,6 +254,7 @@ def system_position_dynamics( quaternion=W_Q_B, omega=W_ω_WB, omega_in_body_fixed=False, + K=baumgarte_quaternion_regularization, ).squeeze() return W_ṗ_B, W_Q̇_B, ṡ @@ -262,6 +267,7 @@ def system_dynamics( *, joint_forces: jtp.Vector | None = None, link_forces: jtp.Vector | None = None, + baumgarte_quaternion_regularization: jtp.FloatLike = 1.0, ) -> tuple[ODEState, dict[str, Any]]: """ Compute the dynamics of the system. @@ -271,6 +277,9 @@ def system_dynamics( data: The data of the considered model. joint_forces: The joint forces to apply. link_forces: The 6D forces to apply to the links. + baumgarte_quaternion_regularization: + The Baumgarte regularization coefficient used to adjust the norm of the + quaternion (only used in integrators not operating on the SO(3) manifold). Returns: A tuple with an `ODEState` object storing in each of its attributes the @@ -287,7 +296,11 @@ def system_dynamics( ) # Extract the velocities. - W_ṗ_B, W_Q̇_B, ṡ = system_position_dynamics(model=model, data=data) + W_ṗ_B, W_Q̇_B, ṡ = system_position_dynamics( + model=model, + data=data, + baumgarte_quaternion_regularization=baumgarte_quaternion_regularization, + ) # Create an ODEState object populated with the derivative of each leaf. # Our integrators, operating on generic pytrees, will be able to handle it From 9d98def644ac54f7ea590aa7ef24b99303d9fafb Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Tue, 11 Jun 2024 10:49:27 +0200 Subject: [PATCH 04/26] Simplify condition to check JAX_ENABLE_X64 env var --- src/jaxsim/__init__.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/jaxsim/__init__.py b/src/jaxsim/__init__.py index ee588a8fa..d9cf2c2e0 100644 --- a/src/jaxsim/__init__.py +++ b/src/jaxsim/__init__.py @@ -8,8 +8,9 @@ def _jnp_options() -> None: import jax - # Enable by default - if not ("JAX_ENABLE_X64" in os.environ and os.environ["JAX_ENABLE_X64"] == "0"): + # Enable by default 64bit precision in JAX. + if os.environ.get("JAX_ENABLE_X64", "1") != "0": + logging.info("Enabling JAX to use 64bit precision") jax.config.update("jax_enable_x64", True) From 0e55a1aff710960b3425d5dba10e599bb533c72d Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Tue, 11 Jun 2024 10:49:57 +0200 Subject: [PATCH 05/26] Add jaxsim.typing.Scalar and jaxsim.typing.ScalarLike --- src/jaxsim/typing.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/jaxsim/typing.py b/src/jaxsim/typing.py index 3f8aadfa8..6d4f66d4b 100644 --- a/src/jaxsim/typing.py +++ b/src/jaxsim/typing.py @@ -24,6 +24,7 @@ # ======================= Array = jax.typing.ArrayLike +Scalar = Array Vector = Array Matrix = Array @@ -31,6 +32,7 @@ Bool = bool | ArrayJax Float = float | FloatJax +ScalarLike = Scalar | int | float ArrayLike = Array VectorLike = Vector MatrixLike = Matrix From 19552f0ecf78214f51efb705d6c22bb6ce92de28 Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Tue, 11 Jun 2024 10:50:28 +0200 Subject: [PATCH 06/26] Update condition to exclude body-fixed representation in CoM test --- tests/test_api_com.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/tests/test_api_com.py b/tests/test_api_com.py index 1598b81c7..559689401 100644 --- a/tests/test_api_com.py +++ b/tests/test_api_com.py @@ -53,11 +53,7 @@ def test_com_properties( assert pytest.approx(v_avg_com_idt) == v_avg_com_js # https://github.com/ami-iit/jaxsim/pull/117#discussion_r1535486123 - with data.switch_velocity_representation( - data.velocity_representation - if data.velocity_representation is not VelRepr.Body - else VelRepr.Mixed - ): + if data.velocity_representation is not VelRepr.Body: vl_com_idt = kin_dyn.com_velocity() vl_com_js = js.com.com_linear_velocity(model=model, data=data) assert pytest.approx(vl_com_idt) == vl_com_js From 5e2897bc3811b57454dfbcf8472dbe8b848ef15f Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Tue, 11 Jun 2024 10:52:03 +0200 Subject: [PATCH 07/26] Update link pose check of reduced models --- tests/test_api_model.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/tests/test_api_model.py b/tests/test_api_model.py index 77e1107d4..5524a6b8f 100644 --- a/tests/test_api_model.py +++ b/tests/test_api_model.py @@ -135,18 +135,19 @@ def test_model_creation_and_reduction( ) # Check that link transforms match. - for link_name, link_idx in zip( - model_reduced.link_names(), - js.link.names_to_idxs( - model=model_reduced, link_names=model_reduced.link_names() - ), - ): + for link_name in model_reduced.link_names(): + assert kin_dyn_reduced.frame_transform(frame_name=link_name) == pytest.approx( kin_dyn_full.frame_transform(frame_name=link_name) ) + assert kin_dyn_reduced.frame_transform(frame_name=link_name) == pytest.approx( js.link.transform( - model=model_reduced, data=data_reduced, link_index=link_idx + model=model_reduced, + data=data_reduced, + link_index=js.link.name_to_idx( + model=model_reduced, link_name=link_name + ), ) ) From c4e5d33bbdfcb12077e920d9e28f44aeaf6e08e9 Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Tue, 11 Jun 2024 11:53:26 +0200 Subject: [PATCH 08/26] Add frame pose check of reduced models --- tests/test_api_model.py | 28 ++++++++++++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/tests/test_api_model.py b/tests/test_api_model.py index 5524a6b8f..5a15f1fd0 100644 --- a/tests/test_api_model.py +++ b/tests/test_api_model.py @@ -139,7 +139,7 @@ def test_model_creation_and_reduction( assert kin_dyn_reduced.frame_transform(frame_name=link_name) == pytest.approx( kin_dyn_full.frame_transform(frame_name=link_name) - ) + ), link_name assert kin_dyn_reduced.frame_transform(frame_name=link_name) == pytest.approx( js.link.transform( @@ -149,7 +149,31 @@ def test_model_creation_and_reduction( model=model_reduced, link_name=link_name ), ) - ) + ), link_name + + # Check that frame transforms match. + for frame_name in model_reduced.frame_names(): + + if frame_name not in kin_dyn_reduced.frame_names(): + continue + + # Skip some entry of models with many frames. + if "skin" in frame_name or "laser" in frame_name or "depth" in frame_name: + continue + + assert kin_dyn_reduced.frame_transform(frame_name=frame_name) == pytest.approx( + kin_dyn_full.frame_transform(frame_name=frame_name) + ), frame_name + + assert kin_dyn_reduced.frame_transform(frame_name=frame_name) == pytest.approx( + js.frame.transform( + model=model_reduced, + data=data_reduced, + frame_index=js.frame.name_to_idx( + model=model_reduced, frame_name=frame_name + ), + ) + ), frame_name def test_model_properties( From f383d8b3f0d2aefc51153f33c74b21ee76bbe304 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Tue, 11 Jun 2024 11:00:40 +0200 Subject: [PATCH 09/26] Fix EOFs and trailing whitespaces --- .gitattributes | 1 - README.md | 8 ++++---- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/.gitattributes b/.gitattributes index 16ef5c5f7..d5799bd69 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1,3 +1,2 @@ # GitHub syntax highlighting pixi.lock linguist-language=YAML - diff --git a/README.md b/README.md index 682767163..4454b239b 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ JaxSim is a **differentiable physics engine** and **multibody dynamics library** designed for applications in control and robot learning, implemented with JAX. -Its design facilitates research and accelerates prototyping in the intersection of robotics and artificial intelligence. +Its design facilitates research and accelerates prototyping in the intersection of robotics and artificial intelligence. ## Features @@ -25,7 +25,7 @@ Its design facilitates research and accelerates prototyping in the intersection ### JaxSim as a multibody dynamics library -- Provides rigid body dynamics algorithms (RBDAs) like RNEA, ABA, CRBA, and Jacobians. +- Provides rigid body dynamics algorithms (RBDAs) like RNEA, ABA, CRBA, and Jacobians. - Provides all the quantities included in the Euler-Poincarè formulation of the equations of motion. - Supports body-fixed, inertial-fixed, and mixed [velocity representations][notation]. - Exposes all the necessary quantities to develop controllers in centroidal coordinates. @@ -132,10 +132,10 @@ The main differences between MJX/Brax and JaxSim are as follows: - JaxSim supports out-of-the-box all SDF models with [Pose Frame Semantics][PFS]. - JaxSim only supports collisions between points rigidly attached to bodies and a compliant ground surface. - Our contact model requires careful tuning of its spring-damper parameters, but being an instantaneous + Our contact model requires careful tuning of its spring-damper parameters, but being an instantaneous function of the state $(\mathbf{q}, \boldsymbol{\nu})$, it doesn't require running any optimization algorithm when stepping the simulation forward. -- JaxSim mitigates the stiffness of the contact-aware system dynamics by providing variable-step integrators. +- JaxSim mitigates the stiffness of the contact-aware system dynamics by providing variable-step integrators. [brax]: https://github.com/google/brax [mjx]: https://mujoco.readthedocs.io/en/3.0.0/mjx.html From 3564357b177e577db4932814c11a8c11030aa71e Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Tue, 11 Jun 2024 11:18:20 +0200 Subject: [PATCH 10/26] Allow overriding default logging verbosity --- src/jaxsim/__init__.py | 41 ++++++++++++++++++++++++++++++++++++----- 1 file changed, 36 insertions(+), 5 deletions(-) diff --git a/src/jaxsim/__init__.py b/src/jaxsim/__init__.py index d9cf2c2e0..4f15e0ea5 100644 --- a/src/jaxsim/__init__.py +++ b/src/jaxsim/__init__.py @@ -28,6 +28,7 @@ def _np_options() -> None: def _is_editable() -> bool: + import importlib.util import pathlib import site @@ -46,11 +47,40 @@ def _is_editable() -> bool: return jaxsim_package_dir not in site.getsitepackages() -# Initialize the logging verbosity -if _is_editable(): - logging.configure(level=logging.LoggingLevel.DEBUG) -else: - logging.configure(level=logging.LoggingLevel.WARNING) +def _get_default_logging_level(env_var: str) -> logging.LoggingLevel: + """ + Get the default logging level. + + Args: + env_var: The environment variable to check. + + Returns: + The logging level to set. + """ + + import os + + # Define the default logging level depending on the installation mode. + default_logging_level = ( + logging.LoggingLevel.DEBUG + if _is_editable() # noqa: F821 + else logging.LoggingLevel.WARNING + ) + + # Allow to override the default logging level with an environment variable. + try: + return logging.LoggingLevel[ + os.environ.get(env_var, default_logging_level.name).upper() + ] + + except KeyError as exc: + msg = f"Invalid logging level defined in {env_var}='{os.environ[env_var]}'" + raise RuntimeError(msg) from exc + + +# Configure the logger with the default logging level. +logging.configure(level=_get_default_logging_level(env_var="JAXSIM_LOGGING_LEVEL")) + # Configure JAX _jnp_options() @@ -60,6 +90,7 @@ def _is_editable() -> bool: del _jnp_options del _np_options +del _get_default_logging_level del _is_editable from . import terrain # isort:skip From 4d62b6f231f7e4ce2c0760a5223dbaae6bdaa44e Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Tue, 11 Jun 2024 11:54:00 +0200 Subject: [PATCH 11/26] Extend model reduction test --- tests/test_api_model.py | 45 ++++++++++++++++++++++++++++++++++++++++- 1 file changed, 44 insertions(+), 1 deletion(-) diff --git a/tests/test_api_model.py b/tests/test_api_model.py index 5a15f1fd0..c89a700af 100644 --- a/tests/test_api_model.py +++ b/tests/test_api_model.py @@ -94,7 +94,7 @@ def test_model_creation_and_reduction( # Check that all non-fixed joints are in the reduced model. assert set(reduced_joints) == set(model_reduced.joint_names()) - # Check that the reduce model maintain the same terrain of the full model. + # Check that the reduced model maintains the same terrain of the full model. assert model_full.terrain == model_reduced.terrain # Build the data of the reduced model. @@ -113,6 +113,49 @@ def test_model_creation_and_reduction( velocity_representation=data_full.velocity_representation, ) + # Check that the reduced model data is valid. + assert not data_reduced.valid(model=model_full) + assert data_reduced.valid(model=model_reduced) + + # Check that the total mass is preserved. + assert js.model.total_mass(model=model_full) == pytest.approx( + js.model.total_mass(model=model_reduced) + ) + + # Check that the CoM position is preserved. + assert js.com.com_position(model=model_full, data=data_full) == pytest.approx( + js.com.com_position(model=model_reduced, data=data_reduced), abs=1e-6 + ) + + # Check that joint serialization works. + assert data_full.joint_positions( + model=model_full, joint_names=model_reduced.joint_names() + ) == pytest.approx(data_reduced.joint_positions()) + assert data_full.joint_velocities( + model=model_full, joint_names=model_reduced.joint_names() + ) == pytest.approx(data_reduced.joint_velocities()) + + # Check that link transforms are preserved. + for link_name in model_reduced.link_names(): + W_H_L_full = js.link.transform( + model=model_full, + data=data_full, + link_index=js.link.name_to_idx(model=model_full, link_name=link_name), + ) + W_H_L_reduced = js.link.transform( + model=model_reduced, + data=data_reduced, + link_index=js.link.name_to_idx(model=model_reduced, link_name=link_name), + ) + assert W_H_L_full == pytest.approx(W_H_L_reduced) + + # Check that collidable point positions are preserved. + assert js.contact.collidable_point_positions( + model=model_full, data=data_full + ) == pytest.approx( + js.contact.collidable_point_positions(model=model_reduced, data=data_reduced) + ) + # ===================== # Test against iDynTree # ===================== From b201859f8e98cb2b1ea3350e54c96d071408414d Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Tue, 11 Jun 2024 12:25:16 +0200 Subject: [PATCH 12/26] Remove unused variables --- src/jaxsim/mujoco/loaders.py | 2 +- src/jaxsim/mujoco/model.py | 2 +- src/jaxsim/rbda/crba.py | 4 ++-- src/jaxsim/rbda/forward_kinematics.py | 2 +- tests/test_automatic_differentiation.py | 6 +++--- tests/test_simulations.py | 2 +- 6 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/jaxsim/mujoco/loaders.py b/src/jaxsim/mujoco/loaders.py index eef5bcaf9..6660da587 100644 --- a/src/jaxsim/mujoco/loaders.py +++ b/src/jaxsim/mujoco/loaders.py @@ -352,7 +352,7 @@ def convert( # Set alpha=0 to the color of all collision elements for geometry_element in mujoco_element.findall(".//geom[@rgba]"): if geometry_element.attrib.get("name") in collision_names: - r, g, b, a = geometry_element.attrib["rgba"].split(" ") + r, g, b, _ = geometry_element.attrib["rgba"].split(" ") geometry_element.set("rgba", f"{r} {g} {b} 0") # ----------------------- diff --git a/src/jaxsim/mujoco/model.py b/src/jaxsim/mujoco/model.py index 62d2912f8..b4f677a3e 100644 --- a/src/jaxsim/mujoco/model.py +++ b/src/jaxsim/mujoco/model.py @@ -73,7 +73,7 @@ def build_from_xml( new_hfield = generate_hfield(heightmap, (nrow, ncol)) model.hfield_data = new_hfield - return MujocoModelHelper(model=model, data=mj.MjData(model)) + return MujocoModelHelper(model=model, data=data) def time(self) -> float: """Return the simulation time.""" diff --git a/src/jaxsim/rbda/crba.py b/src/jaxsim/rbda/crba.py index 18fe8e30f..27ee83042 100644 --- a/src/jaxsim/rbda/crba.py +++ b/src/jaxsim/rbda/crba.py @@ -111,7 +111,7 @@ def while_loop_body(carry: CarryInnerFn) -> CarryInnerFn: # a while loop using a for loop with fixed number of iterations. def inner_fn(carry: CarryInnerFn, k: jtp.Int) -> tuple[CarryInnerFn, None]: def compute_inner(carry: CarryInnerFn) -> tuple[CarryInnerFn, None]: - j, Fi, M = carry + j, _, _ = carry out = jax.lax.cond( pred=(λ[j] > 0), true_fun=while_loop_body, @@ -120,7 +120,7 @@ def compute_inner(carry: CarryInnerFn) -> tuple[CarryInnerFn, None]: ) return out, None - j, Fi, M = carry + j, _, _ = carry return jax.lax.cond( pred=(k == j), true_fun=compute_inner, diff --git a/src/jaxsim/rbda/forward_kinematics.py b/src/jaxsim/rbda/forward_kinematics.py index 55a81d390..8bcab038a 100644 --- a/src/jaxsim/rbda/forward_kinematics.py +++ b/src/jaxsim/rbda/forward_kinematics.py @@ -49,7 +49,7 @@ def forward_kinematics_model( # Compute the parent-to-child adjoints and the motion subspaces of the joints. # These transforms define the relative kinematics of the entire model, including # the base transform for both floating-base and fixed-base models. - i_X_λi, S = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces( + i_X_λi, _ = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces( joint_positions=s, base_transform=W_H_B.as_matrix() ) diff --git a/tests/test_automatic_differentiation.py b/tests/test_automatic_differentiation.py index 31b2ae07f..cdb408c59 100644 --- a/tests/test_automatic_differentiation.py +++ b/tests/test_automatic_differentiation.py @@ -179,7 +179,7 @@ def test_ad_crba( model = jaxsim_models_types _, subkey = jax.random.split(prng_key, num=2) - data, references = get_random_data_and_references( + data, _ = get_random_data_and_references( model=model, velocity_representation=VelRepr.Inertial, key=subkey ) @@ -211,7 +211,7 @@ def test_ad_fk( model = jaxsim_models_types _, subkey = jax.random.split(prng_key, num=2) - data, references = get_random_data_and_references( + data, _ = get_random_data_and_references( model=model, velocity_representation=VelRepr.Inertial, key=subkey ) @@ -250,7 +250,7 @@ def test_ad_jacobian( model = jaxsim_models_types _, subkey = jax.random.split(prng_key, num=2) - data, references = get_random_data_and_references( + data, _ = get_random_data_and_references( model=model, velocity_representation=VelRepr.Inertial, key=subkey ) diff --git a/tests/test_simulations.py b/tests/test_simulations.py index 0bd96bf4e..4968a177f 100644 --- a/tests/test_simulations.py +++ b/tests/test_simulations.py @@ -102,7 +102,7 @@ def test_box_with_zero_gravity( model = jaxsim_model_box # Split the PRNG key. - key, subkey, subkey2 = jax.random.split(prng_key, num=3) + _, subkey, subkey2 = jax.random.split(prng_key, num=3) # Build the data of the model. data0 = js.data.JaxSimModelData.build( From a19f18e60fab680786d440fcd583999343d985cd Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Tue, 11 Jun 2024 12:25:38 +0200 Subject: [PATCH 13/26] Avoid redefining existing functions --- src/jaxsim/api/contact.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/jaxsim/api/contact.py b/src/jaxsim/api/contact.py index 4d66d240f..92ceee83b 100644 --- a/src/jaxsim/api/contact.py +++ b/src/jaxsim/api/contact.py @@ -365,20 +365,20 @@ def jacobian( W_H_C = transforms(model=model, data=data) - def jacobian(W_H_C: jtp.Matrix, W_J_WC: jtp.Matrix) -> jtp.Matrix: + def body_jacobian(W_H_C: jtp.Matrix, W_J_WC: jtp.Matrix) -> jtp.Matrix: C_X_W = jaxsim.math.Adjoint.from_transform( transform=W_H_C, inverse=True ) C_J_WC = C_X_W @ W_J_WC return C_J_WC - O_J_WC = jax.vmap(jacobian)(W_H_C, W_J_WC) + O_J_WC = jax.vmap(body_jacobian)(W_H_C, W_J_WC) case VelRepr.Mixed: W_H_C = transforms(model=model, data=data) - def jacobian(W_H_C: jtp.Matrix, W_J_WC: jtp.Matrix) -> jtp.Matrix: + def mixed_jacobian(W_H_C: jtp.Matrix, W_J_WC: jtp.Matrix) -> jtp.Matrix: W_H_CW = W_H_C.at[0:3, 0:3].set(jnp.eye(3)) @@ -389,7 +389,7 @@ def jacobian(W_H_C: jtp.Matrix, W_J_WC: jtp.Matrix) -> jtp.Matrix: CW_J_WC = CW_X_W @ W_J_WC return CW_J_WC - O_J_WC = jax.vmap(jacobian)(W_H_C, W_J_WC) + O_J_WC = jax.vmap(mixed_jacobian)(W_H_C, W_J_WC) case _: raise ValueError(output_vel_repr) From e25ec0a99e4d4a6c553b72375f749a51cf34a74a Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Tue, 11 Jun 2024 12:26:23 +0200 Subject: [PATCH 14/26] Remove double import --- src/jaxsim/api/model.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index a65575d58..656732bce 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -687,8 +687,6 @@ def to_active( another representation C_v̇_WB expressed in a generic frame C. """ - from jaxsim.math import Cross - # In Mixed representation, we need to include a cross product in ℝ⁶. # In Inertial and Body representations, the cross product is always zero. C_X_W = jaxlie.SE3.from_matrix(W_H_C).inverse().adjoint() From 741cf67dcffa8e9a3135007bfc4031ee29fa0937 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Tue, 11 Jun 2024 12:26:48 +0200 Subject: [PATCH 15/26] Use set comprehensions --- src/jaxsim/mujoco/loaders.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/jaxsim/mujoco/loaders.py b/src/jaxsim/mujoco/loaders.py index 6660da587..d2af6bd0b 100644 --- a/src/jaxsim/mujoco/loaders.py +++ b/src/jaxsim/mujoco/loaders.py @@ -188,10 +188,9 @@ def convert( ) # If considered joints are passed, make sure that they are all part of the model. - if considered_joints - set([j.name for j in rod_model.joints()]): - extra_joints = set(considered_joints) - set( - [j.name for j in rod_model.joints()] - ) + if considered_joints - {j.name for j in rod_model.joints()}: + extra_joints = set(considered_joints) - {j.name for j in rod_model.joints()} + msg = f"Couldn't find the following joints in the model: '{extra_joints}'" raise ValueError(msg) From 1063e2efaadb60d6f4033f4a07dc208f03222c55 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Tue, 11 Jun 2024 12:27:28 +0200 Subject: [PATCH 16/26] Update deprecated `typing.Hashable` for Python 3.12 --- src/jaxsim/typing.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/jaxsim/typing.py b/src/jaxsim/typing.py index 6d4f66d4b..5d56467c0 100644 --- a/src/jaxsim/typing.py +++ b/src/jaxsim/typing.py @@ -1,4 +1,5 @@ -from typing import Any, Hashable +from collections.abc import Hashable +from typing import Any import jax From aefe446a1d494c3b54b3c6bb0f9819f923406967 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Tue, 11 Jun 2024 12:28:01 +0200 Subject: [PATCH 17/26] Avoid non-assigned expressions --- src/jaxsim/mujoco/visualizer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/jaxsim/mujoco/visualizer.py b/src/jaxsim/mujoco/visualizer.py index 9af33631a..810391f36 100644 --- a/src/jaxsim/mujoco/visualizer.py +++ b/src/jaxsim/mujoco/visualizer.py @@ -173,4 +173,4 @@ def open( try: yield handle finally: - handle.close() if close_on_exit else None + _ = handle.close() if close_on_exit else None From 24e86edd6ab0d00824d5a4a292cf5d423a21ce6d Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Tue, 11 Jun 2024 12:41:14 +0200 Subject: [PATCH 18/26] Update typing --- src/jaxsim/math/joint_model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/jaxsim/math/joint_model.py b/src/jaxsim/math/joint_model.py index a641fa7c0..e770cb42a 100644 --- a/src/jaxsim/math/joint_model.py +++ b/src/jaxsim/math/joint_model.py @@ -34,8 +34,8 @@ class JointModel: already in a vectorized form. In other words, it cannot be created using vmap. """ - λ_H_pre: jax.Array - suc_H_i: jax.Array + λ_H_pre: jtp.Array + suc_H_i: jtp.Array joint_dofs: Static[tuple[int, ...]] joint_names: Static[tuple[str, ...]] From 15ab7c5dec34b1a1d8ffe23e8fd841d7626d5713 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Tue, 11 Jun 2024 12:41:32 +0200 Subject: [PATCH 19/26] Add `KinematicGraph.joints_removed` property --- src/jaxsim/parsers/descriptions/model.py | 6 +++--- src/jaxsim/parsers/kinematic_graph.py | 11 +++++++++++ tests/utils_idyntree.py | 2 +- 3 files changed, 15 insertions(+), 4 deletions(-) diff --git a/src/jaxsim/parsers/descriptions/model.py b/src/jaxsim/parsers/descriptions/model.py index 6cd264cfa..d3ba2de62 100644 --- a/src/jaxsim/parsers/descriptions/model.py +++ b/src/jaxsim/parsers/descriptions/model.py @@ -145,7 +145,7 @@ def build_model_from( root=kinematic_graph.root, joints=kinematic_graph.joints, frames=kinematic_graph.frames, - _joints_removed=kinematic_graph._joints_removed, + _joints_removed=kinematic_graph.joints_removed, ) # Check that the root link of kinematic graph is the desired base link. @@ -182,8 +182,8 @@ def reduce(self, considered_joints: Sequence[str]) -> ModelDescription: ) # Include the unconnected/removed joints from the original model. - for joint in self._joints_removed: - reduced_model_description._joints_removed.append(joint) + for joint in self.joints_removed: + reduced_model_description.joints_removed.append(joint) return reduced_model_description diff --git a/src/jaxsim/parsers/kinematic_graph.py b/src/jaxsim/parsers/kinematic_graph.py index 83083c781..fa5ff07ef 100644 --- a/src/jaxsim/parsers/kinematic_graph.py +++ b/src/jaxsim/parsers/kinematic_graph.py @@ -615,6 +615,17 @@ def print_tree(self) -> None: horizontal=True, ) + @property + def joints_removed(self) -> list[descriptions.JointDescription]: + """ + Get the list of joints removed during the graph reduction. + + Returns: + The list of removed joints. + """ + + return self._joints_removed + @staticmethod def breadth_first_search( root: descriptions.LinkDescription, diff --git a/tests/utils_idyntree.py b/tests/utils_idyntree.py index 61cbe4a2b..7040fcab7 100644 --- a/tests/utils_idyntree.py +++ b/tests/utils_idyntree.py @@ -52,7 +52,7 @@ def build_kindyncomputations_from_jaxsim_model( # Get the default positions already stored in the model description. removed_joint_positions_default = { str(j.name): float(j.initial_position) - for j in model.description._joints_removed + for j in model.description.joints_removed if j.name not in considered_joints } From baf07bc30ea7f5e23d82eeeb856e47c62bf3e3d5 Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Tue, 11 Jun 2024 13:11:33 +0200 Subject: [PATCH 20/26] Install only "gz sdf" command instead of full Gazebo Sim in CI --- .github/workflows/ci_cd.yml | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/.github/workflows/ci_cd.yml b/.github/workflows/ci_cd.yml index aeffddf8a..a5f7e3e78 100644 --- a/.github/workflows/ci_cd.yml +++ b/.github/workflows/ci_cd.yml @@ -118,8 +118,7 @@ jobs: - *src - *tests - # https://gazebosim.org/docs/harmonic/install_ubuntu - - name: Install Gazebo Sim + - name: Install 'gz sdf' system command if: | contains(matrix.os, 'ubuntu') && (github.event_name != 'pull_request' || @@ -130,7 +129,7 @@ jobs: sudo wget https://packages.osrfoundation.org/gazebo.gpg -O /usr/share/keyrings/pkgs-osrf-archive-keyring.gpg echo "deb [arch=$(dpkg --print-architecture) signed-by=/usr/share/keyrings/pkgs-osrf-archive-keyring.gpg] http://packages.osrfoundation.org/gazebo/ubuntu-stable $(lsb_release -cs) main" | sudo tee /etc/apt/sources.list.d/gazebo-stable.list > /dev/null sudo apt-get update - sudo apt-get install gz-harmonic + sudo apt-get install --no-install-recommends libsdformat13 gz-tools2 - name: Run the Python tests if: | From b7be8ccfa6bdeb5b6481332561950ccc7290fdf1 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Tue, 11 Jun 2024 18:05:36 +0200 Subject: [PATCH 21/26] Update `.gitignore` --- .gitignore | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/.gitignore b/.gitignore index 3c4f1c472..7bc6fee25 100644 --- a/.gitignore +++ b/.gitignore @@ -140,5 +140,10 @@ src/jaxsim/_version.py # ruff .ruff_cache/ + # pixi environments .pixi + +# data +.mp4 +.png From e9a3f8329179dcf066e93e46c8fcbcbb0ad48cf9 Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Wed, 12 Jun 2024 15:30:31 +0200 Subject: [PATCH 22/26] Make jaxsim.description classes hashable --- src/jaxsim/parsers/descriptions/collision.py | 81 ++++++++++++++------ src/jaxsim/parsers/descriptions/joint.py | 56 ++++++++++---- src/jaxsim/parsers/descriptions/link.py | 15 ++-- src/jaxsim/parsers/descriptions/model.py | 36 +++++++-- src/jaxsim/parsers/kinematic_graph.py | 36 ++++++--- 5 files changed, 159 insertions(+), 65 deletions(-) diff --git a/src/jaxsim/parsers/descriptions/collision.py b/src/jaxsim/parsers/descriptions/collision.py index 289ab70be..31ae17b97 100644 --- a/src/jaxsim/parsers/descriptions/collision.py +++ b/src/jaxsim/parsers/descriptions/collision.py @@ -1,11 +1,13 @@ +from __future__ import annotations + import abc import dataclasses -from typing import List import jax.numpy as jnp import numpy as np import numpy.typing as npt +import jaxsim.typing as jtp from jaxsim import logging from .link import LinkDescription @@ -17,9 +19,9 @@ class CollidablePoint: Represents a collidable point associated with a parent link. Attributes: - parent_link (LinkDescription): The parent link to which the collidable point is attached. - position (npt.NDArray): The position of the collidable point relative to the parent link. - enabled (bool): A flag indicating whether the collidable point is enabled for collision detection. + parent_link: The parent link to which the collidable point is attached. + position: The position of the collidable point relative to the parent link. + enabled: A flag indicating whether the collidable point is enabled for collision detection. """ @@ -29,7 +31,7 @@ class CollidablePoint: def change_link( self, new_link: LinkDescription, new_H_old: npt.NDArray - ) -> "CollidablePoint": + ) -> CollidablePoint: """ Move the collidable point to a new parent link. @@ -39,8 +41,8 @@ def change_link( Returns: CollidablePoint: A new collidable point associated with the new parent link. - """ + msg = f"Moving collidable point: {self.parent_link.name} -> {new_link.name}" logging.debug(msg=msg) @@ -50,15 +52,24 @@ def change_link( enabled=self.enabled, ) - def __eq__(self, other): - retval = ( - self.parent_link == other.parent_link - and (self.position == other.position).all() - and self.enabled == other.enabled + def __hash__(self) -> int: + + return hash( + ( + hash(self.parent_link), + hash(tuple(self.position.tolist())), + hash(self.enabled), + ) ) - return retval - def __str__(self): + def __eq__(self, other: CollidablePoint) -> bool: + + if not isinstance(other, CollidablePoint): + return False + + return hash(self) == hash(other) + + def __str__(self) -> str: return ( f"{self.__class__.__name__}(" + f"parent_link={self.parent_link.name}" @@ -74,11 +85,11 @@ class CollisionShape(abc.ABC): Abstract base class for representing collision shapes. Attributes: - collidable_points (List[CollidablePoint]): A list of collidable points associated with the collision shape. + collidable_points: A list of collidable points associated with the collision shape. """ - collidable_points: List[CollidablePoint] + collidable_points: tuple[CollidablePoint] def __str__(self): return ( @@ -95,14 +106,26 @@ class BoxCollision(CollisionShape): Represents a box-shaped collision shape. Attributes: - center (npt.NDArray): The center of the box in the local frame of the collision shape. + center: The center of the box in the local frame of the collision shape. """ - center: npt.NDArray + center: jtp.VectorLike - def __eq__(self, other): - return (self.center == other.center).all() and super().__eq__(other) + def __hash__(self) -> int: + return hash( + ( + hash(super()), + hash(tuple(self.center.tolist())), + ) + ) + + def __eq__(self, other: BoxCollision) -> bool: + + if not isinstance(other, BoxCollision): + return False + + return hash(self) == hash(other) @dataclasses.dataclass @@ -111,11 +134,23 @@ class SphereCollision(CollisionShape): Represents a spherical collision shape. Attributes: - center (npt.NDArray): The center of the sphere in the local frame of the collision shape. + center: The center of the sphere in the local frame of the collision shape. """ - center: npt.NDArray + center: jtp.VectorLike + + def __hash__(self) -> int: + return hash( + ( + hash(super()), + hash(tuple(self.center.tolist())), + ) + ) + + def __eq__(self, other: BoxCollision) -> bool: + + if not isinstance(other, BoxCollision): + return False - def __eq__(self, other): - return (self.center == other.center).all() and super().__eq__(other) + return hash(self) == hash(other) diff --git a/src/jaxsim/parsers/descriptions/joint.py b/src/jaxsim/parsers/descriptions/joint.py index a2139e0e6..1aa4eef84 100644 --- a/src/jaxsim/parsers/descriptions/joint.py +++ b/src/jaxsim/parsers/descriptions/joint.py @@ -1,11 +1,10 @@ from __future__ import annotations import dataclasses -from typing import ClassVar, Tuple, Union +from typing import ClassVar import jax_dataclasses import numpy as np -import numpy.typing as npt import jaxsim.typing as jtp from jaxsim.utils import JaxsimDataclass, Mutability @@ -15,6 +14,7 @@ @dataclasses.dataclass(frozen=True) class JointType: + Fixed: ClassVar[int] = 0 Revolute: ClassVar[int] = 1 Prismatic: ClassVar[int] = 2 @@ -64,29 +64,31 @@ class JointDescription(JaxsimDataclass): """ name: jax_dataclasses.Static[str] - axis: npt.NDArray - pose: npt.NDArray - jtype: jax_dataclasses.Static[JointType] + axis: jtp.Vector + pose: jtp.Matrix + jtype: jax_dataclasses.Static[jtp.IntLike] child: LinkDescription = dataclasses.dataclass(repr=False) parent: LinkDescription = dataclasses.dataclass(repr=False) - index: int | None = None + index: jtp.IntLike | None = None + + friction_static: jtp.FloatLike = 0.0 + friction_viscous: jtp.FloatLike = 0.0 - friction_static: float = 0.0 - friction_viscous: float = 0.0 + position_limit_damper: jtp.FloatLike = 0.0 + position_limit_spring: jtp.FloatLike = 0.0 - position_limit_damper: float = 0.0 - position_limit_spring: float = 0.0 + position_limit: tuple[jtp.FloatLike, jtp.FloatLike] = (0.0, 0.0) + initial_position: jtp.FloatLike | jtp.VectorLike = 0.0 - position_limit: Tuple[float, float] = (0.0, 0.0) - initial_position: Union[float, npt.NDArray] = 0.0 + motor_inertia: jtp.FloatLike = 0.0 + motor_viscous_friction: jtp.FloatLike = 0.0 + motor_gear_ratio: jtp.FloatLike = 1.0 - motor_inertia: float = 0.0 - motor_viscous_friction: float = 0.0 - motor_gear_ratio: float = 1.0 + def __post_init__(self) -> None: - def __post_init__(self): if self.axis is not None: + with self.mutable_context( mutability=Mutability.MUTABLE, restore_after_exception=False ): @@ -94,4 +96,24 @@ def __post_init__(self): self.axis = self.axis / norm_of_axis def __hash__(self) -> int: - return hash(self.__repr__()) + + return hash( + ( + hash(self.name), + hash(tuple(self.axis.tolist())), + hash(tuple(self.pose.flatten().tolist())), + hash(int(self.jtype)), + hash(self.child), + hash(self.parent), + hash(int(self.index)) if self.index is not None else 0, + hash(float(self.friction_static)), + hash(float(self.friction_viscous)), + hash(float(self.position_limit_damper)), + hash(float(self.position_limit_spring)), + hash((float(el) for el in self.position_limit)), + hash(tuple(np.atleast_1d(self.initial_position).tolist())), + hash(float(self.motor_inertia)), + hash(float(self.motor_viscous_friction)), + hash(float(self.motor_gear_ratio)), + ), + ) diff --git a/src/jaxsim/parsers/descriptions/link.py b/src/jaxsim/parsers/descriptions/link.py index ec9129477..859aa7122 100644 --- a/src/jaxsim/parsers/descriptions/link.py +++ b/src/jaxsim/parsers/descriptions/link.py @@ -5,6 +5,7 @@ import jax.numpy as jnp import jax_dataclasses import jaxlie +import numpy as np from jax_dataclasses import Static import jaxsim.typing as jtp @@ -23,7 +24,7 @@ class LinkDescription(JaxsimDataclass): index: An optional index for the link (it gets automatically assigned). parent: The parent link of this link. pose: The pose transformation matrix of the link. - children: List of child links. + children: The children links. """ name: Static[str] @@ -33,7 +34,7 @@ class LinkDescription(JaxsimDataclass): parent: LinkDescription = dataclasses.field(default=None, repr=False) pose: jtp.Matrix = dataclasses.field(default_factory=lambda: jnp.eye(4), repr=False) - children: Static[list[LinkDescription]] = dataclasses.field( + children: Static[tuple[LinkDescription]] = dataclasses.field( default_factory=list, repr=False ) @@ -43,10 +44,12 @@ def __hash__(self) -> int: ( hash(self.name), hash(float(self.mass)), - hash(tuple(self.inertia.flatten().tolist())), - hash(int(self.index)), - hash(self.parent), - hash(tuple(hash(c) for c in self.children)), + hash(tuple(np.atleast_1d(self.inertia).flatten().tolist())), + hash(int(self.index)) if self.index is not None else 0, + hash(tuple(np.atleast_1d(self.pose).flatten().tolist())), + hash(tuple(self.children)), + # Here only using the name to prevent circular recursion: + hash(self.parent.name) if self.parent is not None else 0, ) ) diff --git a/src/jaxsim/parsers/descriptions/model.py b/src/jaxsim/parsers/descriptions/model.py index d3ba2de62..14f294ac2 100644 --- a/src/jaxsim/parsers/descriptions/model.py +++ b/src/jaxsim/parsers/descriptions/model.py @@ -27,7 +27,7 @@ class ModelDescription(KinematicGraph): fixed_base: bool = True - collision_shapes: list[CollisionShape] = dataclasses.field( + collision_shapes: tuple[CollisionShape, ...] = dataclasses.field( default_factory=list, repr=False, hash=False ) @@ -37,7 +37,7 @@ def build_model_from( links: list[LinkDescription], joints: list[JointDescription], frames: list[LinkDescription] | None = None, - collisions: list[CollisionShape] = (), + collisions: tuple[CollisionShape, ...] = (), fixed_base: bool = False, base_link_name: str | None = None, considered_joints: Sequence[str] | None = None, @@ -87,7 +87,7 @@ def build_model_from( for collision_shape in collisions: # Get all the collidable points of the shape - coll_points = list(collision_shape.collidable_points) + coll_points = tuple(collision_shape.collidable_points) # Assume they have an unique parent link if not len(set({cp.parent_link.name for cp in coll_points})) == 1: @@ -111,7 +111,7 @@ def build_model_from( continue # Create a new collision shape - new_collision_shape = CollisionShape(collidable_points=[]) + new_collision_shape = CollisionShape(collidable_points=()) final_collisions.append(new_collision_shape) # If the frame was found, update the collidable points' pose and add them @@ -133,15 +133,15 @@ def build_model_from( ), ) - # Store the updated collision - new_collision_shape.collidable_points.append(moved_cp) + # Store the updated collision. + new_collision_shape.collidable_points += (moved_cp,) # Build the model model = ModelDescription( name=name, root_pose=kinematic_graph.root_pose, fixed_base=fixed_base, - collision_shapes=final_collisions, + collision_shapes=tuple(final_collisions), root=kinematic_graph.root, joints=kinematic_graph.joints, frames=kinematic_graph.frames, @@ -174,7 +174,7 @@ def reduce(self, considered_joints: Sequence[str]) -> ModelDescription: links=list(self.links_dict.values()), joints=self.joints, frames=self.frames, - collisions=self.collision_shapes, + collisions=tuple(self.collision_shapes), fixed_base=self.fixed_base, base_link_name=list(iter(self))[0].name, model_pose=self.root_pose, @@ -243,3 +243,23 @@ def all_enabled_collidable_points(self) -> list[CollidablePoint]: # Return enabled collidable points return [cp for cp in all_collidable_points if cp.enabled] + + def __eq__(self, other: ModelDescription) -> bool: + + if not isinstance(other, ModelDescription): + return False + + return hash(self) == hash(other) + + def __hash__(self) -> int: + + return hash( + ( + hash(self.name), + hash(self.fixed_base), + hash(self.root), + hash(tuple(self.joints)), + hash(tuple(self.frames)), + hash(self.root_pose), + ) + ) diff --git a/src/jaxsim/parsers/kinematic_graph.py b/src/jaxsim/parsers/kinematic_graph.py index fa5ff07ef..7588af660 100644 --- a/src/jaxsim/parsers/kinematic_graph.py +++ b/src/jaxsim/parsers/kinematic_graph.py @@ -31,14 +31,21 @@ class RootPose(NamedTuple): root_position: npt.NDArray = np.zeros(3) root_quaternion: npt.NDArray = np.array([1.0, 0, 0, 0]) + def __hash__(self) -> int: + + return hash( + ( + hash(tuple(self.root_position.tolist())), + hash(tuple(self.root_quaternion.tolist())), + ) + ) + def __eq__(self, other: RootPose) -> bool: if not isinstance(other, RootPose): return False - return np.allclose(self.root_position, other.root_position) and np.allclose( - self.root_quaternion, other.root_quaternion - ) + return hash(self) == hash(other) @dataclasses.dataclass(frozen=True) @@ -54,22 +61,24 @@ class KinematicGraph(Sequence[descriptions.LinkDescription]): """ root: descriptions.LinkDescription - frames: list[descriptions.LinkDescription] = dataclasses.field(default_factory=list) + frames: list[descriptions.LinkDescription] = dataclasses.field( + default_factory=list, hash=False, compare=False + ) joints: list[descriptions.JointDescription] = dataclasses.field( - default_factory=list + default_factory=list, hash=False, compare=False ) root_pose: RootPose = dataclasses.field(default_factory=lambda: RootPose()) # Private attribute storing optional additional info. _extra_info: dict[str, Any] = dataclasses.field( - repr=False, compare=False, default_factory=dict + default_factory=dict, repr=False, hash=False, compare=False ) # Private attribute storing the unconnected joints from the parsed model and # the joints removed after model reduction. _joints_removed: list[descriptions.JointDescription] = dataclasses.field( - default_factory=list, repr=False, compare=False + default_factory=list, repr=False, hash=False, compare=False ) @functools.cached_property @@ -98,14 +107,17 @@ def __post_init__(self) -> None: for index, link in enumerate(self): link.mutable(validate=False).index = index - # Get the names of the links and frames. + # Get the names of the links, frames, and joints. link_names = [l.name for l in self] frame_names = [f.name for f in self.frames] + joint_names = [j.name for j in self.joints] # Make sure that they are unique. assert len(link_names) == len(set(link_names)) assert len(frame_names) == len(set(frame_names)) + assert len(joint_names) == len(set(joint_names)) assert set(link_names).isdisjoint(set(frame_names)) + assert set(link_names).isdisjoint(set(joint_names)) # Order frames with their name. super().__setattr__("frames", sorted(self.frames, key=lambda f: f.name)) @@ -251,7 +263,7 @@ def _create_graph( # Reset the connections of the root link. for link in links_dict.values(): - link.children = [] + link.children = tuple() # Couple links and joints creating the kinematic graph. for joint in joints: @@ -268,7 +280,8 @@ def _create_graph( # Assign link's children and make sure they are unique. if child_link.name not in {l.name for l in parent_link.children}: - parent_link.children.append(child_link) + with parent_link.mutable_context(Mutability.MUTABLE_NO_VALIDATION): + parent_link.children = parent_link.children + (child_link,) # Collect all the links of the kinematic graph. all_links_in_graph = list( @@ -315,7 +328,7 @@ def _create_graph( # Update the unconnected links by removing their children. The other properties # are left untouched, it's caller responsibility to post-process them if needed. for link in unconnected_links: - link.children = [] + link.children = tuple() msg = "Link '{}' won't be part of the kinematic graph because unconnected" logging.debug(msg=msg.format(link.name)) @@ -796,6 +809,7 @@ def transform(self, name: str) -> npt.NDArray: # Get the joint. joint = self.graph.joints_dict[name] + assert joint.name == name # Get the transform of the parent link. M_H_L = self.transform(name=joint.parent.name) From 3f41551df5409083126dd2a00c553b127501c1c0 Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Wed, 12 Jun 2024 16:48:57 +0200 Subject: [PATCH 23/26] Include ModelDescription in JaxSimModel hash --- src/jaxsim/api/kin_dyn_parameters.py | 5 ++++- src/jaxsim/api/model.py | 15 ++++++--------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/jaxsim/api/kin_dyn_parameters.py b/src/jaxsim/api/kin_dyn_parameters.py index 01c2d4eca..c185b83fc 100644 --- a/src/jaxsim/api/kin_dyn_parameters.py +++ b/src/jaxsim/api/kin_dyn_parameters.py @@ -6,6 +6,7 @@ import jax.numpy as jnp import jax_dataclasses import jaxlie +import numpy as np from jax_dataclasses import Static import jaxsim.typing as jtp @@ -220,7 +221,9 @@ def __hash__(self) -> int: ( hash(self.number_of_links()), hash(self.number_of_joints()), - hash(tuple(jnp.atleast_1d(self.parent_array).flatten().tolist())), + hash(tuple(np.atleast_1d(self.parent_array).flatten().tolist())), + hash(self._parent_array), + hash(self._support_body_array_bool), ) ) diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index 656732bce..4498a8a92 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -17,7 +17,7 @@ import jaxsim.parsers.descriptions import jaxsim.typing as jtp from jaxsim.math import Cross -from jaxsim.utils import HashlessObject, JaxsimDataclass, Mutability +from jaxsim.utils import JaxsimDataclass, Mutability from .common import VelRepr @@ -42,13 +42,9 @@ class JaxSimModel(JaxsimDataclass): default=None, repr=False, compare=False, hash=False ) - _description: Static[ - HashlessObject[jaxsim.parsers.descriptions.ModelDescription | None] - ] = dataclasses.field(default=None, repr=False, compare=False, hash=False) - - @property - def description(self) -> jaxsim.parsers.descriptions.ModelDescription: - return self._description.get() + description: Static[jaxsim.parsers.descriptions.ModelDescription | None] = ( + dataclasses.field(default=None, repr=False, compare=False, hash=False) + ) def __eq__(self, other: JaxSimModel) -> bool: @@ -62,6 +58,7 @@ def __hash__(self) -> int: return hash( ( hash(self.model_name), + hash(self.description), hash(self.kin_dyn_parameters), ) ) @@ -158,7 +155,7 @@ def build( # Build the model model = JaxSimModel( model_name=model_name, - _description=HashlessObject(obj=model_description), + description=model_description, kin_dyn_parameters=js.kin_dyn_parameters.KynDynParameters.build( model_description=model_description ), From bf9d7a2f30b13528dad254b196c2cadd5dd32c24 Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Wed, 12 Jun 2024 16:49:02 +0200 Subject: [PATCH 24/26] Make sure that mutable attributes do not get cross-altered in reduction --- src/jaxsim/api/model.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index 4498a8a92..3fafaa1ee 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -313,10 +313,9 @@ def reduce( new_joints = set(model.joint_names()) - set(locked_joint_positions) raise ValueError(f"Passed joints not existing in the model: {new_joints}") - # Copy the model description with a deep copy of the joints. - intermediate_description = dataclasses.replace( - model.description, joints=copy.deepcopy(model.description.joints) - ) + # Operate on a deep copy of the model description in order to prevent problems + # when mutable attributes are updated. + intermediate_description = copy.deepcopy(model.description) # Update the initial position of the joints. # This is necessary to compute the correct pose of the link pairs connected From 8947957fe7f24cf18553b9f3ab03eae37f5a185f Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Wed, 12 Jun 2024 17:09:07 +0200 Subject: [PATCH 25/26] Minor import changes --- src/jaxsim/parsers/rod/parser.py | 2 +- src/jaxsim/parsers/rod/utils.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/jaxsim/parsers/rod/parser.py b/src/jaxsim/parsers/rod/parser.py index 73bead465..752345cd3 100644 --- a/src/jaxsim/parsers/rod/parser.py +++ b/src/jaxsim/parsers/rod/parser.py @@ -7,7 +7,7 @@ import rod from jaxsim import logging -from jaxsim.math.quaternion import Quaternion +from jaxsim.math import Quaternion from jaxsim.parsers import descriptions, kinematic_graph from . import utils diff --git a/src/jaxsim/parsers/rod/utils.py b/src/jaxsim/parsers/rod/utils.py index a001a1da7..aa0c6b128 100644 --- a/src/jaxsim/parsers/rod/utils.py +++ b/src/jaxsim/parsers/rod/utils.py @@ -6,7 +6,7 @@ import rod import jaxsim.typing as jtp -from jaxsim.math.inertia import Inertia +from jaxsim.math import Inertia from jaxsim.parsers import descriptions From 82502a332edc22a5a7916720b11762debcb516cb Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Wed, 12 Jun 2024 17:09:45 +0200 Subject: [PATCH 26/26] Minor typing update --- src/jaxsim/parsers/rod/utils.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/jaxsim/parsers/rod/utils.py b/src/jaxsim/parsers/rod/utils.py index aa0c6b128..d5ebbb15e 100644 --- a/src/jaxsim/parsers/rod/utils.py +++ b/src/jaxsim/parsers/rod/utils.py @@ -59,9 +59,7 @@ def from_sdf_inertial(inertial: rod.Inertial) -> jtp.Matrix: return M_L.astype(dtype=float) -def joint_to_joint_type( - joint: rod.Joint, -) -> descriptions.JointType: +def joint_to_joint_type(joint: rod.Joint) -> int: """ Extract the joint type from an SDF joint. @@ -69,7 +67,7 @@ def joint_to_joint_type( joint: The parsed SDF joint. Returns: - The corresponding joint type description. + The integer corresponding to the joint type. """ axis = joint.axis @@ -138,7 +136,7 @@ def create_box_collision( collidable_points = [ descriptions.CollidablePoint( parent_link=link_description, - position=corner, + position=np.array(corner), enabled=True, ) for corner in box_corners_wrt_link.T @@ -197,7 +195,7 @@ def fibonacci_sphere(samples: int) -> npt.NDArray: collidable_points = [ descriptions.CollidablePoint( parent_link=link_description, - position=point, + position=np.array(point), enabled=True, ) for point in sphere_points_wrt_link.T