From f3e35b1033955e00c3b6c7c1c62124e668a3fa3d Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti <102977828+flferretti@users.noreply.github.com> Date: Fri, 14 Jun 2024 16:26:02 +0200 Subject: [PATCH] Refactor contact forces sum in `api.ode.system_velocity_dynamics` Co-authored-by: Alessandro Croci --- src/jaxsim/api/ode.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/src/jaxsim/api/ode.py b/src/jaxsim/api/ode.py index 5d6463c85..f0bdf7757 100644 --- a/src/jaxsim/api/ode.py +++ b/src/jaxsim/api/ode.py @@ -150,14 +150,12 @@ def system_velocity_dynamics( # Sum the forces of all collidable points rigidly attached to a body. # Since the contact forces W_f_Ci are expressed in the world frame, # we don't need any coordinate transformation. - W_f_Li_terrain = jax.vmap( - lambda nc: ( - jnp.vstack( - jnp.equal(parent_link_index_of_collidable_points, nc).astype(int) - ) - * W_f_Ci - ).sum(axis=0) - )(jnp.arange(model.number_of_links())) + W_f_Li_terrain = jnp.where( + parent_link_index_of_collidable_points[:, jnp.newaxis] + == jnp.arange(model.number_of_links()), + W_f_Ci, + 0.0, + ).sum(axis=0) # ==================== # Enforce joint limits