diff --git a/tests/test_api_link.py b/tests/test_api_link.py index ebe0e90fb..1e808687d 100644 --- a/tests/test_api_link.py +++ b/tests/test_api_link.py @@ -223,6 +223,74 @@ def test_link_bias_acceleration( Jν_js = js.link.bias_acceleration(model=model, data=data, link_index=index) assert pytest.approx(Jν_idt) == Jν_js + # Test that the conversion of the link bias acceleration works as expected. + match data.velocity_representation: + + # We exclude the mixed representation because converting the acceleration is + # more complex than using the plain 6D transform matrix. + case VelRepr.Mixed: + pass + + # Inertial-fixed to body-fixed conversion. + case VelRepr.Inertial: + + W_H_L = js.model.forward_kinematics(model=model, data=data) + + W_a_bias_WL = jax.vmap( + lambda index: js.link.bias_acceleration( + model=model, data=data, link_index=index + ) + )(jnp.arange(model.number_of_links())) + + with data.switch_velocity_representation(VelRepr.Body): + + W_X_L = jax.vmap( + lambda W_H_L: jaxsim.math.Adjoint.from_transform(transform=W_H_L) + )(W_H_L) + + L_a_bias_WL = jax.vmap( + lambda index: js.link.bias_acceleration( + model=model, data=data, link_index=index + ) + )(jnp.arange(model.number_of_links())) + + W_a_bias_WL_converted = jax.vmap( + lambda W_X_L, L_a_bias_WL: W_X_L @ L_a_bias_WL + )(W_X_L, L_a_bias_WL) + + assert W_a_bias_WL == pytest.approx(W_a_bias_WL_converted) + + # Body-fixed to inertial-fixed conversion. + case VelRepr.Body: + + W_H_L = js.model.forward_kinematics(model=model, data=data) + + L_a_bias_WL = jax.vmap( + lambda index: js.link.bias_acceleration( + model=model, data=data, link_index=index + ) + )(jnp.arange(model.number_of_links())) + + with data.switch_velocity_representation(VelRepr.Inertial): + + L_X_W = jax.vmap( + lambda W_H_L: jaxsim.math.Adjoint.from_transform( + transform=W_H_L, inverse=True + ) + )(W_H_L) + + W_a_bias_WL = jax.vmap( + lambda index: js.link.bias_acceleration( + model=model, data=data, link_index=index + ) + )(jnp.arange(model.number_of_links())) + + L_a_bias_WL_converted = jax.vmap( + lambda L_X_W, W_a_bias_WL: L_X_W @ W_a_bias_WL + )(L_X_W, W_a_bias_WL) + + assert L_a_bias_WL == pytest.approx(L_a_bias_WL_converted) + def test_link_jacobian_derivative( jaxsim_models_types: js.model.JaxSimModel,