From a9e3b2ffa813150c764a15bdd1d51d8b8b4dc93f Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Tue, 18 Jun 2024 14:37:06 +0200 Subject: [PATCH] Make `compute_contact_forces` return `tuple[jtp.Vector, tuple[Any, ...]]` --- src/jaxsim/api/contact.py | 2 +- src/jaxsim/rbda/contacts/soft.py | 10 +++++----- tests/test_automatic_differentiation.py | 3 ++- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/src/jaxsim/api/contact.py b/src/jaxsim/api/contact.py index 5de7e43ce..1fc9e0a28 100644 --- a/src/jaxsim/api/contact.py +++ b/src/jaxsim/api/contact.py @@ -149,7 +149,7 @@ def collidable_point_dynamics( # collidable point, and the corresponding material deformation rate. # Note that the material deformation rate is always returned in the mixed frame # C[W] = (W_p_C, [W]). This is convenient for integration purpose. - W_f_Ci, CW_ṁ = jax.vmap(soft_contacts.compute_contact_forces)( + W_f_Ci, (CW_ṁ,) = jax.vmap(soft_contacts.compute_contact_forces)( W_p_Ci, W_ṗ_Ci, data.state.contact.tangential_deformation ) diff --git a/src/jaxsim/rbda/contacts/soft.py b/src/jaxsim/rbda/contacts/soft.py index f199efdb6..0328eef76 100644 --- a/src/jaxsim/rbda/contacts/soft.py +++ b/src/jaxsim/rbda/contacts/soft.py @@ -158,7 +158,7 @@ def compute_contact_forces( position: jtp.Vector, velocity: jtp.Vector, tangential_deformation: jtp.Vector, - ) -> tuple[jtp.Vector, jtp.Vector]: + ) -> tuple[jtp.Vector, tuple[jtp.Vector, None]]: """ Compute the contact forces and material deformation rate. @@ -237,7 +237,7 @@ def with_no_friction(): # Compute lin-ang 6D forces (inertial representation) W_f = W_Xf_CW @ CW_f - return W_f, ṁ + return W_f, (ṁ,) # ========================= # Compute tangential forces @@ -255,7 +255,7 @@ def with_friction(): active_contact = pz < self.terrain.height(x=px, y=py) def above_terrain(): - return jnp.zeros(6), ṁ + return jnp.zeros(6), (ṁ,) def below_terrain(): # Decompose the velocity in normal and tangential components @@ -311,9 +311,9 @@ def slipping_contact(): W_f = W_Xf_CW @ CW_f # Return the 6D force in the world frame and the deformation derivative - return W_f, ṁ + return W_f, (ṁ,) - # (W_f, ṁ) + # (W_f, (ṁ,)) return jax.lax.cond( pred=active_contact, true_fun=lambda _: below_terrain(), diff --git a/tests/test_automatic_differentiation.py b/tests/test_automatic_differentiation.py index d27110137..31281ce9f 100644 --- a/tests/test_automatic_differentiation.py +++ b/tests/test_automatic_differentiation.py @@ -308,9 +308,10 @@ def close_over_inputs_and_parameters( m: jtp.VectorLike, params: SoftContactsParams, ) -> tuple[jtp.Vector, jtp.Vector]: - return SoftContacts(parameters=params).compute_contact_forces( + W_f_Ci, (CW_ṁ,) = SoftContacts(parameters=params).compute_contact_forces( position=p, velocity=v, tangential_deformation=m ) + return W_f_Ci, CW_ṁ # Check derivatives against finite differences. check_grads(