diff --git a/src/jaxsim/api/contact.py b/src/jaxsim/api/contact.py index 7b593e897..b83571515 100644 --- a/src/jaxsim/api/contact.py +++ b/src/jaxsim/api/contact.py @@ -154,7 +154,7 @@ def collidable_point_dynamics( # Note that the material deformation rate is always returned in the mixed frame # C[W] = (W_p_C, [W]). This is convenient for integration purpose. W_f_Ci, CW_ṁ = jax.vmap(soft_contacts.contact_model)( - W_p_Ci, W_ṗ_Ci, data.state.contact_state.tangential_deformation + W_p_Ci, W_ṗ_Ci, data.state.contacts_state.tangential_deformation ) case _: diff --git a/src/jaxsim/api/ode.py b/src/jaxsim/api/ode.py index 5d6463c85..86f617d1d 100644 --- a/src/jaxsim/api/ode.py +++ b/src/jaxsim/api/ode.py @@ -132,7 +132,7 @@ def system_velocity_dynamics( W_f_Ci = None # Initialize the derivative of the tangential deformation ṁ ∈ ℝ^{n_c × 3}. - ṁ = jnp.zeros_like(data.state.soft_contacts.tangential_deformation).astype(float) + ṁ = jnp.zeros_like(data.state.contacts_state.tangential_deformation).astype(float) if len(model.kin_dyn_parameters.contact_parameters.body) > 0: # Compute the 6D forces applied to each collidable point and the diff --git a/src/jaxsim/api/ode_data.py b/src/jaxsim/api/ode_data.py index be8479d9e..1b37b5f69 100644 --- a/src/jaxsim/api/ode_data.py +++ b/src/jaxsim/api/ode_data.py @@ -1,10 +1,13 @@ from __future__ import annotations +import importlib + import jax.numpy as jnp import jax_dataclasses import jaxsim.api as js import jaxsim.typing as jtp +from jaxsim import logging from jaxsim.api.soft_contacts import SoftContactsState from jaxsim.utils import JaxsimDataclass @@ -117,11 +120,11 @@ class ODEState(JaxsimDataclass): Attributes: physics_model: The state of the physics model. - soft_contacts: The state of the soft-contacts model. + contacts_state: The state of the contacts model. """ physics_model: PhysicsModelState - soft_contacts: SoftContactsState + contacts_state: js.contact.ContactsState @staticmethod def build_from_jaxsim_model( @@ -159,6 +162,15 @@ def build_from_jaxsim_model( `JaxSimModel` and initialized to zero. """ + # Get the contact model from the `JaxSimModel` + prefix = type(model.contact_model).__name__.split("Contact")[0] + + if prefix: + module_name = f"{prefix.lower()}_contacts" + class_name = f"{prefix.capitalize()}ContactsState" + else: + raise ValueError("Unable to determine contact state class prefix.") + return ODEState.build( model=model, physics_model_state=PhysicsModelState.build_from_jaxsim_model( @@ -170,24 +182,30 @@ def build_from_jaxsim_model( base_linear_velocity=base_linear_velocity, base_angular_velocity=base_angular_velocity, ), - soft_contacts_state=SoftContactsState.build_from_jaxsim_model( + contacts_state=getattr( + importlib.import_module(f"jaxsim.api.{module_name}"), class_name + ).build_from_jaxsim_model( model=model, - tangential_deformation=tangential_deformation, + **( + dict(tangential_deformation=tangential_deformation) + if tangential_deformation is not None + else dict() + ), ), ) @staticmethod def build( physics_model_state: PhysicsModelState | None = None, - soft_contacts_state: SoftContactsState | None = None, + contacts_state: js.contact.ContactsState | None = None, model: js.model.JaxSimModel | None = None, ) -> ODEState: """ - Build an `ODEState` from a `PhysicsModelState` and a `SoftContactsState`. + Build an `ODEState` from a `PhysicsModelState` and a `ContactsState`. Args: physics_model_state: The state of the physics model. - soft_contacts_state: The state of the soft-contacts model. + contacts_state: The state of the contacts model. model: The `JaxSimModel` associated with the ODE state. Returns: @@ -200,14 +218,33 @@ def build( else PhysicsModelState.zero(model=model) ) - soft_contacts_state = ( - soft_contacts_state - if soft_contacts_state is not None + # Get the contact model from the `JaxSimModel` + try: + prefix = type(model.contact_model).__name__.split("Contact")[0] + except AttributeError: + logging.warning( + "Unable to determine contact state class prefix. Using default soft contacts." + ) + prefix = "Soft" + + module_name = f"{prefix.lower()}_contacts" + class_name = f"{prefix.capitalize()}ContactsState" + + try: + state_cls = getattr( + importlib.import_module(f"jaxsim.api.{module_name}"), class_name + ) + except ImportError as e: + raise e + + contacts_state = ( + contacts_state + if contacts_state is not None else SoftContactsState.zero(model=model) ) return ODEState( - physics_model=physics_model_state, soft_contacts=soft_contacts_state + physics_model=physics_model_state, contacts_state=contacts_state ) @staticmethod @@ -237,7 +274,7 @@ def valid(self, model: js.model.JaxSimModel) -> bool: `True` if the ODE state is valid for the given model, `False` otherwise. """ - return self.physics_model.valid(model=model) and self.soft_contacts.valid( + return self.physics_model.valid(model=model) and self.contacts_state.valid( model=model ) diff --git a/tests/test_automatic_differentiation.py b/tests/test_automatic_differentiation.py index 6e9e5f793..c843b9024 100644 --- a/tests/test_automatic_differentiation.py +++ b/tests/test_automatic_differentiation.py @@ -342,7 +342,7 @@ def test_ad_integration( s = data.joint_positions(model=model) W_v_WB = data.base_velocity() ṡ = data.joint_velocities(model=model) - m = data.state.soft_contacts.tangential_deformation + m = data.state.contacts_state.tangential_deformation # Inputs. W_f_L = references.link_forces(model=model) @@ -417,7 +417,7 @@ def step( xf_s = data_xf.joint_positions(model=model) xf_W_v_WB = data_xf.base_velocity() xf_ṡ = data_xf.joint_velocities(model=model) - xf_m = data_xf.state.soft_contacts.tangential_deformation + xf_m = data_xf.state.contacts_state.tangential_deformation return xf_W_p_B, xf_W_Q_B, xf_s, xf_W_v_WB, xf_ṡ, xf_m