From 4e7f13b1c144a660961b34de4d7b1115eb61feab Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Mon, 10 Jun 2024 09:42:50 +0200 Subject: [PATCH] =?UTF-8?q?Test=20J=CC=87=20computation=20against=20AD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/test_api_link.py | 65 +++++++++++++++++++++++++++++++++++++++++- 1 file changed, 64 insertions(+), 1 deletion(-) diff --git a/tests/test_api_link.py b/tests/test_api_link.py index 8816cc46e..d5b7c48be 100644 --- a/tests/test_api_link.py +++ b/tests/test_api_link.py @@ -3,6 +3,7 @@ import pytest import jaxsim.api as js +import jaxsim.math from jaxsim import VelRepr from . import utils_idyntree @@ -250,4 +251,66 @@ def test_link_jacobian_derivative( )(js.link.names_to_idxs(model=model, link_names=model.link_names())) # Compare the two computations. - assert jnp.einsum("l6g,g->l6", O_J̇_WL_I, I_ν) == pytest.approx(O_a_bias_WL) + assert jnp.einsum("l6g,g->l6", O_J̇_WL_I, I_ν) == pytest.approx( + O_a_bias_WL, abs=1e-9 + ) + + # Compute the plain Jacobian. + # This function will be used to compute the Jacobian derivative with AD. + # Given q, computing J̇ by AD-ing this function should work out-of-the-box with + # all velocity representations, that are handled internally when computing J. + def J(q) -> jax.Array: + + data_ad = js.data.JaxSimModelData.zero( + model=model, velocity_representation=data.velocity_representation + ) + + data_ad = data_ad.reset_base_position(base_position=q[:3]) + data_ad = data_ad.reset_base_quaternion(base_quaternion=q[3:7]) + data_ad = data_ad.reset_joint_positions(positions=q[7:]) + + O_J_WL_I = js.model.generalized_free_floating_jacobian( + model=model, data=data_ad + ) + + return O_J_WL_I + + def compute_q(data: js.data.JaxSimModelData) -> jax.Array: + + q = jnp.hstack( + [data.base_position(), data.base_orientation(), data.joint_positions()] + ) + + return q + + def compute_q̇(data: js.data.JaxSimModelData) -> jax.Array: + + with data.switch_velocity_representation(VelRepr.Body): + B_ω_WB = data.base_velocity()[3:6] + + with data.switch_velocity_representation(VelRepr.Mixed): + W_ṗ_B = data.base_velocity()[0:3] + + W_Q̇_B = jaxsim.math.Quaternion.derivative( + quaternion=data.base_orientation(), + omega=B_ω_WB, + omega_in_body_fixed=True, + K=0.0, + ).squeeze() + + q̇ = jnp.hstack([W_ṗ_B, W_Q̇_B, data.joint_velocities()]) + + return q̇ + + # Compute q and q̇. + q = compute_q(data) + q̇ = compute_q̇(data) + + # Compute dJ/dt with AD. + dJ_dq = jax.jacfwd(J, argnums=0)(q) + O_J̇_ad_WL_I = jnp.einsum("ijkq,q->ijk", dJ_dq, q̇) + + assert O_J̇_ad_WL_I == pytest.approx(O_J̇_WL_I) + assert jnp.einsum("l6g,g->l6", O_J̇_ad_WL_I, I_ν) == pytest.approx( + jnp.einsum("l6g,g->l6", O_J̇_WL_I, I_ν) + )