diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index 65bdb009b..303339e55 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -16,6 +16,7 @@ import jaxsim.api as js import jaxsim.parsers.descriptions import jaxsim.typing as jtp +from jaxsim.math import Cross from jaxsim.utils import HashlessObject, JaxsimDataclass, Mutability from .common import VelRepr @@ -871,6 +872,126 @@ def free_floating_mass_matrix( raise ValueError(data.velocity_representation) +@jax.jit +def free_floating_coriolis_matrix( + model: JaxSimModel, data: js.data.JaxSimModelData +) -> jtp.Matrix: + """ + Compute the free-floating Coriolis matrix of the model. + + Args: + model: The model to consider. + data: The data of the considered model. + + Returns: + The free-floating Coriolis matrix of the model. + + Note: + This function, contrarily to other quantities of the equations of motion, + does not exploit any iterative algorithm. Therefore, the computation of + the Coriolis matrix may be much slower than other quantities. + """ + + # We perform all the calculation in body-fixed. + # The Coriolis matrix computed in this representation is converted later + # to the active representation stored in data. + with data.switch_velocity_representation(VelRepr.Body): + + B_ν = data.generalized_velocity() + + # Doubly-left free-floating Jacobian. + L_J_WL_B = generalized_free_floating_jacobian(model=model, data=data) + + # Doubly-left free-floating Jacobian derivative. + L_J̇_WL_B = jax.vmap( + lambda link_index: js.link.jacobian_derivative( + model=model, data=data, link_index=link_index + ) + )(js.link.names_to_idxs(model=model, link_names=model.link_names())) + + L_M_L = link_spatial_inertia_matrices(model=model) + + # Body-fixed link velocities. + # Note: we could have called link.velocity() instead of computing it ourselves, + # but since we need the link Jacobians later, we can save a double calculation. + L_v_WL = jax.vmap(lambda J: J @ B_ν)(L_J_WL_B) + + # Compute the contribution of each link to the Coriolis matrix. + def compute_link_contribution(M, v, J, J̇) -> jtp.Array: + + return J.T @ ((Cross.vx_star(v) @ M + M @ Cross.vx(v)) @ J + M @ J̇) + + C_B_links = jax.vmap(compute_link_contribution)( + L_M_L, + L_v_WL, + L_J_WL_B, + L_J̇_WL_B, + ) + + # We need to adjust the Coriolis matrix for fixed-base models. + # In this case, the base link does not contribute to the matrix, and we need to zero + # the off-diagonal terms mapping joint quantities onto the base configuration. + if model.floating_base(): + C_B = C_B_links.sum(axis=0) + else: + C_B = C_B_links[1:].sum(axis=0) + C_B = C_B.at[0:6, 6:].set(0.0) + C_B = C_B.at[6:, 0:6].set(0.0) + + # Adjust the representation of the Coriolis matrix. + # Refer to https://github.com/traversaro/traversaro-phd-thesis, Section 3.6. + match data.velocity_representation: + + case VelRepr.Body: + return C_B + + case VelRepr.Inertial: + + n = model.dofs() + W_H_B = data.base_transform() + B_X_W = jaxsim.math.Adjoint.from_transform(W_H_B, inverse=True) + B_T_W = jax.scipy.linalg.block_diag(B_X_W, jnp.eye(n)) + + with data.switch_velocity_representation(VelRepr.Inertial): + W_v_WB = data.base_velocity() + B_Ẋ_W = -B_X_W @ jaxsim.math.Cross.vx(W_v_WB) + + B_Ṫ_W = jax.scipy.linalg.block_diag(B_Ẋ_W, jnp.zeros(shape=(n, n))) + + with data.switch_velocity_representation(VelRepr.Body): + M = free_floating_mass_matrix(model=model, data=data) + + C = B_T_W.T @ (M @ B_Ṫ_W + C_B @ B_T_W) + + return C + + case VelRepr.Mixed: + + n = model.dofs() + BW_H_B = data.base_transform().at[0:3, 3].set(jnp.zeros(3)) + B_X_BW = jaxsim.math.Adjoint.from_transform(transform=BW_H_B, inverse=True) + B_T_BW = jax.scipy.linalg.block_diag(B_X_BW, jnp.eye(n)) + + with data.switch_velocity_representation(VelRepr.Mixed): + BW_v_WB = data.base_velocity() + BW_v_W_BW = BW_v_WB.at[3:6].set(jnp.zeros(3)) + + BW_v_BW_B = BW_v_WB - BW_v_W_BW + B_Ẋ_BW = -B_X_BW @ jaxsim.math.Cross.vx(BW_v_BW_B) + + B_Ṫ_BW = jax.scipy.linalg.block_diag(B_Ẋ_BW, jnp.zeros(shape=(n, n))) + + with data.switch_velocity_representation(VelRepr.Body): + M = free_floating_mass_matrix(model=model, data=data) + + C = B_T_BW.T @ (M @ B_Ṫ_BW + C_B @ B_T_BW) + + return C + + case _: + raise ValueError(data.velocity_representation) + + @jax.jit def inverse_dynamics( model: JaxSimModel, @@ -931,8 +1052,6 @@ def to_inertial(C_v̇_WB, W_H_C, C_v_WB, W_v_WC): expressed in a generic frame C to the inertial-fixed representation W_v̇_WB. """ - from jaxsim.math import Cross - W_X_C = jaxlie.SE3.from_matrix(W_H_C).adjoint() C_X_W = jaxlie.SE3.from_matrix(W_H_C).inverse().adjoint() C_v_WC = C_X_W @ W_v_WC diff --git a/tests/test_api_model.py b/tests/test_api_model.py index f903a9f64..77e1107d4 100644 --- a/tests/test_api_model.py +++ b/tests/test_api_model.py @@ -7,6 +7,7 @@ import rod import jaxsim.api as js +import jaxsim.math from jaxsim import VelRepr from . import utils_idyntree @@ -319,6 +320,95 @@ def test_model_jacobian( assert pytest.approx(JTf_inertial) == JTf_other, vel_repr.name +def test_coriolis_matrix( + jaxsim_models_types: js.model.JaxSimModel, + velocity_representation: VelRepr, + prng_key: jax.Array, +): + + model = jaxsim_models_types + + _, subkey = jax.random.split(prng_key, num=2) + data = js.data.random_model_data( + model=model, key=subkey, velocity_representation=velocity_representation + ) + + # ===== + # Tests + # ===== + + I_ν = data.generalized_velocity() + C = js.model.free_floating_coriolis_matrix(model=model, data=data) + + h = js.model.free_floating_bias_forces(model=model, data=data) + g = js.model.free_floating_gravity_forces(model=model, data=data) + Cν = h - g + + assert C @ I_ν == pytest.approx(Cν) + + # Compute the free-floating mass matrix. + # This function will be used to compute the Ṁ with AD. + # Given q, computing Ṁ by AD-ing this function should work out-of-the-box with + # all velocity representations, that are handled internally when computing M. + def M(q) -> jax.Array: + + data_ad = js.data.JaxSimModelData.zero( + model=model, velocity_representation=data.velocity_representation + ) + + data_ad = data_ad.reset_base_position(base_position=q[:3]) + data_ad = data_ad.reset_base_quaternion(base_quaternion=q[3:7]) + data_ad = data_ad.reset_joint_positions(positions=q[7:]) + + M = js.model.free_floating_mass_matrix(model=model, data=data_ad) + + return M + + def compute_q(data: js.data.JaxSimModelData) -> jax.Array: + + q = jnp.hstack( + [data.base_position(), data.base_orientation(), data.joint_positions()] + ) + + return q + + def compute_q̇(data: js.data.JaxSimModelData) -> jax.Array: + + with data.switch_velocity_representation(VelRepr.Body): + B_ω_WB = data.base_velocity()[3:6] + + with data.switch_velocity_representation(VelRepr.Mixed): + W_ṗ_B = data.base_velocity()[0:3] + + W_Q̇_B = jaxsim.math.Quaternion.derivative( + quaternion=data.base_orientation(), + omega=B_ω_WB, + omega_in_body_fixed=True, + K=0.0, + ).squeeze() + + q̇ = jnp.hstack([W_ṗ_B, W_Q̇_B, data.joint_velocities()]) + + return q̇ + + # Compute q and q̇. + q = compute_q(data) + q̇ = compute_q̇(data) + + # Compute Ṁ with AD. + dM_dq = jax.jacfwd(M, argnums=0)(q) + Ṁ = jnp.einsum("ijq,q->ij", dM_dq, q̇) + + # We need to zero the blocks projecting joint variables to the base configuration + # for fixed-base models. + if not model.floating_base(): + Ṁ = Ṁ.at[0:6, 6:].set(0) + Ṁ = Ṁ.at[6:, 0:6].set(0) + + # Ensure that (Ṁ - 2C) is skew symmetric. + assert Ṁ - C - C.T == pytest.approx(0) + + def test_model_fd_id_consistency( jaxsim_models_types: js.model.JaxSimModel, velocity_representation: VelRepr, diff --git a/tests/test_simulations.py b/tests/test_simulations.py index f300a73a1..0bd96bf4e 100644 --- a/tests/test_simulations.py +++ b/tests/test_simulations.py @@ -169,5 +169,5 @@ def test_box_with_zero_gravity( assert data.base_position() == pytest.approx( data0.base_position() + 0.5 * L_f[:, :3].squeeze() / js.model.total_mass(model=model) * tf**2, - rel=1e-4, + abs=1e-3, )