From c2a457b9c4f4b31474150001e36d9a007d95faa7 Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Fri, 31 May 2024 11:33:59 +0200 Subject: [PATCH 1/5] Add jaxsim.api.contact.transforms --- src/jaxsim/api/contact.py | 37 +++++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/src/jaxsim/api/contact.py b/src/jaxsim/api/contact.py index 64642b8c5..2f1aec027 100644 --- a/src/jaxsim/api/contact.py +++ b/src/jaxsim/api/contact.py @@ -269,3 +269,40 @@ def estimate_model_height(model: js.model.JaxSimModel) -> jtp.Float: ) return sc_parameters + + +@jax.jit +def transforms(model: js.model.JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Array: + r""" + Return the pose of the collidable points. + + Args: + model: The model to consider. + data: The data of the considered model. + + Returns: + The stacked SE(3) matrices of all collidable points. + + Note: + Each collidable point is implicitly associated with a frame + :math:`C = ({}^W p_C, [L])`, where :math:`{}^W p_C` is the position of the + collidable point and :math:`[L]` is the orientation frame of the link it is + rigidly attached to. + """ + + # Get the transforms of the parent link of all collidable points. + W_H_L = jax.vmap( + lambda parent_link_idx: js.link.transform( + model=model, data=data, link_index=parent_link_idx + ) + )(jnp.array(model.kin_dyn_parameters.contact_parameters.body, dtype=int)) + + # Build the link-to-point transform from the displacement between the link frame L + # and the implicit contact frame C. + L_H_C = jax.vmap(lambda L_p_C: jnp.eye(4).at[0:3, 3].set(L_p_C))( + model.kin_dyn_parameters.contact_parameters.point + ) + + # Compose the work-to-link and link-to-point transforms. + return jax.vmap(lambda W_H_Li, L_H_Ci: W_H_Li @ L_H_Ci)(W_H_L, L_H_C) + From 483ac6a58a2eb54b0599bf110ab4c623c04ca768 Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Fri, 31 May 2024 11:34:57 +0200 Subject: [PATCH 2/5] Add jaxsim.api.contact.jacobian --- src/jaxsim/api/contact.py | 84 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 84 insertions(+) diff --git a/src/jaxsim/api/contact.py b/src/jaxsim/api/contact.py index 2f1aec027..56a64187c 100644 --- a/src/jaxsim/api/contact.py +++ b/src/jaxsim/api/contact.py @@ -306,3 +306,87 @@ def transforms(model: js.model.JaxSimModel, data: js.data.JaxSimModelData) -> jt # Compose the work-to-link and link-to-point transforms. return jax.vmap(lambda W_H_Li, L_H_Ci: W_H_Li @ L_H_Ci)(W_H_L, L_H_C) + +@functools.partial(jax.jit, static_argnames=["output_vel_repr"]) +def jacobian( + model: js.model.JaxSimModel, + data: js.data.JaxSimModelData, + *, + output_vel_repr: VelRepr | None = None, +) -> jtp.Array: + r""" + Return the free-floating Jacobian of the collidable points. + + Args: + model: The model to consider. + data: The data of the considered model. + output_vel_repr: + The output velocity representation of the free-floating jacobian. + + Returns: + The stacked 6×(6+n) free-floating jacobians of the frames associated to the + collidable points. + + Note: + Each collidable point is implicitly associated with a frame + :math:`C = ({}^W p_C, [L])`, where :math:`{}^W p_C` is the position of the + collidable point and :math:`[L]` is the orientation frame of the link it is + rigidly attached to. + """ + + output_vel_repr = ( + output_vel_repr if output_vel_repr is not None else data.velocity_representation + ) + + # For each collidable point, get the Jacobians of their parent link. + # In inertial-fixed output representation, the Jacobian of the parent link is also + # the Jacobian of the frame C implicitly associated with the collidable point. + W_J_WC = W_J_WL = jax.vmap( + lambda parent_link_idx: js.link.jacobian( + model=model, + data=data, + link_index=parent_link_idx, + output_vel_repr=VelRepr.Inertial, + ) + )(jnp.array(model.kin_dyn_parameters.contact_parameters.body, dtype=int)) + + # Adjust the output representation. + match output_vel_repr: + + case VelRepr.Inertial: + O_J_WC = W_J_WC + + case VelRepr.Body: + + W_H_C = transforms(model=model, data=data) + + def jacobian(W_H_C: jtp.Matrix, W_J_WC: jtp.Matrix) -> jtp.Matrix: + C_X_W = jaxsim.math.Adjoint.from_transform( + transform=W_H_C, inverse=True + ) + C_J_WCi = C_X_W @ W_J_WC + return C_J_WCi + + O_J_WC = jax.vmap(jacobian)(W_H_C, W_J_WC) + + case VelRepr.Mixed: + + W_H_C = transforms(model=model, data=data) + + def jacobian(W_H_C: jtp.Matrix, W_J_WC: jtp.Matrix) -> jtp.Matrix: + + W_H_CW = W_H_C.at[0:3, 0:3].set(jnp.eye(3)) + + CW_X_W = jaxsim.math.Adjoint.from_transform( + transform=W_H_CW, inverse=True + ) + + CW_J_WC = CW_X_W @ W_J_WC + return CW_J_WC + + O_J_WC = jax.vmap(jacobian)(W_H_C, W_J_WC) + + case _: + raise ValueError(output_vel_repr) + + return O_J_WC From 3ec01244691b71bf1da4f1f9c03f05678f8a100e Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Fri, 31 May 2024 11:35:34 +0200 Subject: [PATCH 3/5] Add new test of jaxsim.api.contact module --- tests/test_contact.py | 37 +++++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) create mode 100644 tests/test_contact.py diff --git a/tests/test_contact.py b/tests/test_contact.py new file mode 100644 index 000000000..4aff7c8ad --- /dev/null +++ b/tests/test_contact.py @@ -0,0 +1,37 @@ +import jax +import pytest + +import jaxsim.api as js +from jaxsim import VelRepr + + +def test_collidable_point_jacobians( + jaxsim_models_types: js.model.JaxSimModel, + velocity_representation: VelRepr, + prng_key: jax.Array, +): + + model = jaxsim_models_types + + key, subkey = jax.random.split(prng_key, num=2) + data = js.data.random_model_data( + model=model, key=subkey, velocity_representation=velocity_representation + ) + + # ===== + # Tests + # ===== + + # Compute the velocity of the collidable points with a RBDA. + # This function always returns the linear part of the mixed velocity of the + # implicit frame C corresponding to the collidable point. + W_ṗ_C = js.contact.collidable_point_velocities(model=model, data=data) + + # Compute the generalized velocity and the free-floating Jacobian of the frame C. + ν = data.generalized_velocity() + CL_J_WC = js.contact.jacobian(model=model, data=data, output_vel_repr=VelRepr.Mixed) + + # Compute the velocity of the collidable points using the Jacobians. + v_WC_from_jax = jax.vmap(lambda J, ν: J @ ν, in_axes=(0, None))(CL_J_WC, ν) + + assert W_ṗ_C == pytest.approx(v_WC_from_jax[:, 0:3]) From acd8a461cdb44d17ed813f6bedfb441f330e7582 Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Fri, 31 May 2024 11:36:09 +0200 Subject: [PATCH 4/5] Minor updates --- src/jaxsim/api/contact.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/jaxsim/api/contact.py b/src/jaxsim/api/contact.py index 56a64187c..36829bc45 100644 --- a/src/jaxsim/api/contact.py +++ b/src/jaxsim/api/contact.py @@ -32,6 +32,7 @@ def collidable_point_kinematics( from jaxsim.rbda import collidable_points + # Switch to inertial-fixed since the RBDAs expect velocities in this representation. with data.switch_velocity_representation(VelRepr.Inertial): W_p_Ci, W_ṗ_Ci = collidable_points.collidable_points_pos_vel( model=model, @@ -61,7 +62,9 @@ def collidable_point_positions( The position of the collidable points in the world frame. """ - return collidable_point_kinematics(model=model, data=data)[0] + W_p_Ci, _ = collidable_point_kinematics(model=model, data=data) + + return W_p_Ci @jax.jit @@ -79,7 +82,9 @@ def collidable_point_velocities( The 3D velocity of the collidable points. """ - return collidable_point_kinematics(model=model, data=data)[1] + _, W_ṗ_Ci = collidable_point_kinematics(model=model, data=data) + + return W_ṗ_Ci @jax.jit From bcd35dca8713c230d8d45f5f1f327e7cf7ff8b5c Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Fri, 31 May 2024 15:28:29 +0200 Subject: [PATCH 5/5] Address review Co-authored-by: Alessandro Croci Co-authored-by: Filippo Luca Ferretti --- src/jaxsim/api/contact.py | 4 ++-- tests/test_contact.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/jaxsim/api/contact.py b/src/jaxsim/api/contact.py index 36829bc45..4d66d240f 100644 --- a/src/jaxsim/api/contact.py +++ b/src/jaxsim/api/contact.py @@ -369,8 +369,8 @@ def jacobian(W_H_C: jtp.Matrix, W_J_WC: jtp.Matrix) -> jtp.Matrix: C_X_W = jaxsim.math.Adjoint.from_transform( transform=W_H_C, inverse=True ) - C_J_WCi = C_X_W @ W_J_WC - return C_J_WCi + C_J_WC = C_X_W @ W_J_WC + return C_J_WC O_J_WC = jax.vmap(jacobian)(W_H_C, W_J_WC) diff --git a/tests/test_contact.py b/tests/test_contact.py index 4aff7c8ad..40c5eca37 100644 --- a/tests/test_contact.py +++ b/tests/test_contact.py @@ -29,9 +29,9 @@ def test_collidable_point_jacobians( # Compute the generalized velocity and the free-floating Jacobian of the frame C. ν = data.generalized_velocity() - CL_J_WC = js.contact.jacobian(model=model, data=data, output_vel_repr=VelRepr.Mixed) + CW_J_WC = js.contact.jacobian(model=model, data=data, output_vel_repr=VelRepr.Mixed) # Compute the velocity of the collidable points using the Jacobians. - v_WC_from_jax = jax.vmap(lambda J, ν: J @ ν, in_axes=(0, None))(CL_J_WC, ν) + v_WC_from_jax = jax.vmap(lambda J, ν: J @ ν, in_axes=(0, None))(CW_J_WC, ν) assert W_ṗ_C == pytest.approx(v_WC_from_jax[:, 0:3])