Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add algorithm to compute the free-floating Coriolis matrix #172

Merged
merged 4 commits into from
Jun 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 121 additions & 2 deletions src/jaxsim/api/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import jaxsim.api as js
import jaxsim.parsers.descriptions
import jaxsim.typing as jtp
from jaxsim.math import Cross
from jaxsim.utils import HashlessObject, JaxsimDataclass, Mutability

from .common import VelRepr
Expand Down Expand Up @@ -871,6 +872,126 @@ def free_floating_mass_matrix(
raise ValueError(data.velocity_representation)


@jax.jit
def free_floating_coriolis_matrix(
model: JaxSimModel, data: js.data.JaxSimModelData
) -> jtp.Matrix:
"""
Compute the free-floating Coriolis matrix of the model.

Args:
model: The model to consider.
data: The data of the considered model.

Returns:
The free-floating Coriolis matrix of the model.

Note:
This function, contrarily to other quantities of the equations of motion,
does not exploit any iterative algorithm. Therefore, the computation of
the Coriolis matrix may be much slower than other quantities.
"""

# We perform all the calculation in body-fixed.
# The Coriolis matrix computed in this representation is converted later
# to the active representation stored in data.
with data.switch_velocity_representation(VelRepr.Body):

B_ν = data.generalized_velocity()

# Doubly-left free-floating Jacobian.
L_J_WL_B = generalized_free_floating_jacobian(model=model, data=data)

# Doubly-left free-floating Jacobian derivative.
L_J̇_WL_B = jax.vmap(
lambda link_index: js.link.jacobian_derivative(
model=model, data=data, link_index=link_index
)
)(js.link.names_to_idxs(model=model, link_names=model.link_names()))

L_M_L = link_spatial_inertia_matrices(model=model)

# Body-fixed link velocities.
# Note: we could have called link.velocity() instead of computing it ourselves,
# but since we need the link Jacobians later, we can save a double calculation.
L_v_WL = jax.vmap(lambda J: J @ B_ν)(L_J_WL_B)

# Compute the contribution of each link to the Coriolis matrix.
def compute_link_contribution(M, v, J, J̇) -> jtp.Array:

return J.T @ ((Cross.vx_star(v) @ M + M @ Cross.vx(v)) @ J + M @ J̇)

C_B_links = jax.vmap(compute_link_contribution)(
L_M_L,
L_v_WL,
L_J_WL_B,
L_J̇_WL_B,
)

# We need to adjust the Coriolis matrix for fixed-base models.
# In this case, the base link does not contribute to the matrix, and we need to zero
# the off-diagonal terms mapping joint quantities onto the base configuration.
if model.floating_base():
C_B = C_B_links.sum(axis=0)
else:
C_B = C_B_links[1:].sum(axis=0)
C_B = C_B.at[0:6, 6:].set(0.0)
C_B = C_B.at[6:, 0:6].set(0.0)

# Adjust the representation of the Coriolis matrix.
# Refer to https://github.com/traversaro/traversaro-phd-thesis, Section 3.6.
match data.velocity_representation:

case VelRepr.Body:
return C_B

case VelRepr.Inertial:

n = model.dofs()
W_H_B = data.base_transform()
B_X_W = jaxsim.math.Adjoint.from_transform(W_H_B, inverse=True)
B_T_W = jax.scipy.linalg.block_diag(B_X_W, jnp.eye(n))

with data.switch_velocity_representation(VelRepr.Inertial):
W_v_WB = data.base_velocity()
B_Ẋ_W = -B_X_W @ jaxsim.math.Cross.vx(W_v_WB)

B_Ṫ_W = jax.scipy.linalg.block_diag(B_Ẋ_W, jnp.zeros(shape=(n, n)))

with data.switch_velocity_representation(VelRepr.Body):
M = free_floating_mass_matrix(model=model, data=data)

C = B_T_W.T @ (M @ B_Ṫ_W + C_B @ B_T_W)

return C

case VelRepr.Mixed:

n = model.dofs()
BW_H_B = data.base_transform().at[0:3, 3].set(jnp.zeros(3))
B_X_BW = jaxsim.math.Adjoint.from_transform(transform=BW_H_B, inverse=True)
B_T_BW = jax.scipy.linalg.block_diag(B_X_BW, jnp.eye(n))

with data.switch_velocity_representation(VelRepr.Mixed):
BW_v_WB = data.base_velocity()
BW_v_W_BW = BW_v_WB.at[3:6].set(jnp.zeros(3))

BW_v_BW_B = BW_v_WB - BW_v_W_BW
B_Ẋ_BW = -B_X_BW @ jaxsim.math.Cross.vx(BW_v_BW_B)

B_Ṫ_BW = jax.scipy.linalg.block_diag(B_Ẋ_BW, jnp.zeros(shape=(n, n)))

with data.switch_velocity_representation(VelRepr.Body):
M = free_floating_mass_matrix(model=model, data=data)

C = B_T_BW.T @ (M @ B_Ṫ_BW + C_B @ B_T_BW)

return C

case _:
raise ValueError(data.velocity_representation)


@jax.jit
def inverse_dynamics(
model: JaxSimModel,
Expand Down Expand Up @@ -931,8 +1052,6 @@ def to_inertial(C_v̇_WB, W_H_C, C_v_WB, W_v_WC):
expressed in a generic frame C to the inertial-fixed representation W_v̇_WB.
"""

from jaxsim.math import Cross

W_X_C = jaxlie.SE3.from_matrix(W_H_C).adjoint()
C_X_W = jaxlie.SE3.from_matrix(W_H_C).inverse().adjoint()
C_v_WC = C_X_W @ W_v_WC
Expand Down
90 changes: 90 additions & 0 deletions tests/test_api_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import rod

import jaxsim.api as js
import jaxsim.math
from jaxsim import VelRepr

from . import utils_idyntree
Expand Down Expand Up @@ -319,6 +320,95 @@ def test_model_jacobian(
assert pytest.approx(JTf_inertial) == JTf_other, vel_repr.name


def test_coriolis_matrix(
jaxsim_models_types: js.model.JaxSimModel,
velocity_representation: VelRepr,
prng_key: jax.Array,
):

model = jaxsim_models_types

_, subkey = jax.random.split(prng_key, num=2)
data = js.data.random_model_data(
model=model, key=subkey, velocity_representation=velocity_representation
)

# =====
# Tests
# =====

I_ν = data.generalized_velocity()
C = js.model.free_floating_coriolis_matrix(model=model, data=data)

h = js.model.free_floating_bias_forces(model=model, data=data)
g = js.model.free_floating_gravity_forces(model=model, data=data)
Cν = h - g

assert C @ I_ν == pytest.approx(Cν)

# Compute the free-floating mass matrix.
# This function will be used to compute the Ṁ with AD.
# Given q, computing Ṁ by AD-ing this function should work out-of-the-box with
# all velocity representations, that are handled internally when computing M.
def M(q) -> jax.Array:

data_ad = js.data.JaxSimModelData.zero(
model=model, velocity_representation=data.velocity_representation
)

data_ad = data_ad.reset_base_position(base_position=q[:3])
data_ad = data_ad.reset_base_quaternion(base_quaternion=q[3:7])
data_ad = data_ad.reset_joint_positions(positions=q[7:])

M = js.model.free_floating_mass_matrix(model=model, data=data_ad)

return M

def compute_q(data: js.data.JaxSimModelData) -> jax.Array:

q = jnp.hstack(
[data.base_position(), data.base_orientation(), data.joint_positions()]
)

return q

def compute_q̇(data: js.data.JaxSimModelData) -> jax.Array:

with data.switch_velocity_representation(VelRepr.Body):
B_ω_WB = data.base_velocity()[3:6]

with data.switch_velocity_representation(VelRepr.Mixed):
W_ṗ_B = data.base_velocity()[0:3]

W_Q̇_B = jaxsim.math.Quaternion.derivative(
quaternion=data.base_orientation(),
omega=B_ω_WB,
omega_in_body_fixed=True,
K=0.0,
).squeeze()

q̇ = jnp.hstack([W_ṗ_B, W_Q̇_B, data.joint_velocities()])

return q̇

# Compute q and q̇.
q = compute_q(data)
q̇ = compute_q̇(data)

# Compute Ṁ with AD.
dM_dq = jax.jacfwd(M, argnums=0)(q)
Ṁ = jnp.einsum("ijq,q->ij", dM_dq, q̇)

# We need to zero the blocks projecting joint variables to the base configuration
# for fixed-base models.
if not model.floating_base():
Ṁ = Ṁ.at[0:6, 6:].set(0)
Ṁ = Ṁ.at[6:, 0:6].set(0)

# Ensure that (Ṁ - 2C) is skew symmetric.
assert Ṁ - C - C.T == pytest.approx(0)


def test_model_fd_id_consistency(
jaxsim_models_types: js.model.JaxSimModel,
velocity_representation: VelRepr,
Expand Down
2 changes: 1 addition & 1 deletion tests/test_simulations.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,5 +169,5 @@ def test_box_with_zero_gravity(
assert data.base_position() == pytest.approx(
data0.base_position()
+ 0.5 * L_f[:, :3].squeeze() / js.model.total_mass(model=model) * tf**2,
rel=1e-4,
abs=1e-3,
)