Skip to content

Commit

Permalink
Add test of Coriolis matrix
Browse files Browse the repository at this point in the history
  • Loading branch information
diegoferigo committed Jun 10, 2024
1 parent ae77abe commit 31ae200
Showing 1 changed file with 90 additions and 0 deletions.
90 changes: 90 additions & 0 deletions tests/test_api_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import rod

import jaxsim.api as js
import jaxsim.math
from jaxsim import VelRepr

from . import utils_idyntree
Expand Down Expand Up @@ -319,6 +320,95 @@ def test_model_jacobian(
assert pytest.approx(JTf_inertial) == JTf_other, vel_repr.name


def test_coriolis_matrix(
jaxsim_models_types: js.model.JaxSimModel,
velocity_representation: VelRepr,
prng_key: jax.Array,
):

model = jaxsim_models_types

key, subkey = jax.random.split(prng_key, num=2)
data = js.data.random_model_data(
model=model, key=subkey, velocity_representation=velocity_representation
)

# =====
# Tests
# =====

I_ν = data.generalized_velocity()
C = js.model.free_floating_coriolis_matrix(model=model, data=data)

h = js.model.free_floating_bias_forces(model=model, data=data)
g = js.model.free_floating_gravity_forces(model=model, data=data)
= h - g

assert C @ I_ν == pytest.approx()

# Compute the free-floating mass matrix.
# This function will be used to compute the Ṁ with AD.
# Given q, computing Ṁ by AD-ing this function should work out-of-the-box with
# all velocity representations, that are handled internally when computing M.
def M(q) -> jax.Array:

data_ad = js.data.JaxSimModelData.zero(
model=model, velocity_representation=data.velocity_representation
)

data_ad = data_ad.reset_base_position(base_position=q[:3])
data_ad = data_ad.reset_base_quaternion(base_quaternion=q[3:7])
data_ad = data_ad.reset_joint_positions(positions=q[7:])

M = js.model.free_floating_mass_matrix(model=model, data=data_ad)

return M

def compute_q(data: js.data.JaxSimModelData) -> jax.Array:

q = jnp.hstack(
[data.base_position(), data.base_orientation(), data.joint_positions()]
)

return q

def compute_q̇(data: js.data.JaxSimModelData) -> jax.Array:

with data.switch_velocity_representation(VelRepr.Body):
B_ω_WB = data.base_velocity()[3:6]

with data.switch_velocity_representation(VelRepr.Mixed):
W_ṗ_B = data.base_velocity()[0:3]

W_Q̇_B = jaxsim.math.Quaternion.derivative(
quaternion=data.base_orientation(),
omega=B_ω_WB,
omega_in_body_fixed=True,
K=0.0,
).squeeze()

= jnp.hstack([W_ṗ_B, W_Q̇_B, data.joint_velocities()])

return

# Compute q and q̇.
q = compute_q(data)
= compute_q̇(data)

# Compute Ṁ with AD.
dM_dq = jax.jacfwd(M, argnums=0)(q)
= jnp.einsum("ijq,q->ij", dM_dq, )

# We need to zero the blocks projecting joint variables to the base configuration
# for fixed-base models.
if not model.floating_base():
= .at[0:6, 6:].set(0)
= .at[6:, 0:6].set(0)

# Ensure that (Ṁ - 2C) is skew symmetric.
assert - C - C.T == pytest.approx(0)


def test_model_fd_id_consistency(
jaxsim_models_types: js.model.JaxSimModel,
velocity_representation: VelRepr,
Expand Down

0 comments on commit 31ae200

Please sign in to comment.