Skip to content

Commit

Permalink
Make compute_contact_forces return `tuple[jtp.Vector, tuple[Any, ..…
Browse files Browse the repository at this point in the history
….]]`
  • Loading branch information
flferretti committed Jun 18, 2024
1 parent ab7a688 commit dfebf77
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 7 deletions.
2 changes: 1 addition & 1 deletion src/jaxsim/api/contact.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def collidable_point_dynamics(
# collidable point, and the corresponding material deformation rate.
# Note that the material deformation rate is always returned in the mixed frame
# C[W] = (W_p_C, [W]). This is convenient for integration purpose.
W_f_Ci, CW_ṁ = jax.vmap(soft_contacts.compute_contact_forces)(
W_f_Ci, (CW_ṁ,) = jax.vmap(soft_contacts.compute_contact_forces)(
W_p_Ci, W_ṗ_Ci, data.state.contact.tangential_deformation
)

Expand Down
10 changes: 5 additions & 5 deletions src/jaxsim/rbda/contacts/soft.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def compute_contact_forces(
position: jtp.Vector,
velocity: jtp.Vector,
tangential_deformation: jtp.Vector,
) -> tuple[jtp.Vector, jtp.Vector]:
) -> tuple[jtp.Vector, tuple[jtp.Vector, None]]:
"""
Compute the contact forces and material deformation rate.
Expand Down Expand Up @@ -237,7 +237,7 @@ def with_no_friction():
# Compute lin-ang 6D forces (inertial representation)
W_f = W_Xf_CW @ CW_f

return W_f,
return W_f, (,)

# =========================
# Compute tangential forces
Expand All @@ -255,7 +255,7 @@ def with_friction():
active_contact = pz < self.terrain.height(x=px, y=py)

def above_terrain():
return jnp.zeros(6),
return jnp.zeros(6), (,)

def below_terrain():
# Decompose the velocity in normal and tangential components
Expand Down Expand Up @@ -311,9 +311,9 @@ def slipping_contact():
W_f = W_Xf_CW @ CW_f

# Return the 6D force in the world frame and the deformation derivative
return W_f,
return W_f, (,)

# (W_f, )
# (W_f, (ṁ,))
return jax.lax.cond(
pred=active_contact,
true_fun=lambda _: below_terrain(),
Expand Down
3 changes: 2 additions & 1 deletion tests/test_automatic_differentiation.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,9 +308,10 @@ def close_over_inputs_and_parameters(
m: jtp.VectorLike,
params: SoftContactsParams,
) -> tuple[jtp.Vector, jtp.Vector]:
return SoftContacts(parameters=params).compute_contact_forces(
W_f_Ci, (CW_ṁ,) = SoftContacts(parameters=params).compute_contact_forces(
position=p, velocity=v, tangential_deformation=m
)
return W_f_Ci, CW_ṁ

# Check derivatives against finite differences.
check_grads(
Expand Down

0 comments on commit dfebf77

Please sign in to comment.