Skip to content

Commit

Permalink
Merge pull request #172 from ami-iit/coriolis_matrix
Browse files Browse the repository at this point in the history
Add algorithm to compute the free-floating Coriolis matrix
  • Loading branch information
diegoferigo committed Jun 11, 2024
2 parents ef02e71 + 31bbed9 commit 0422047
Show file tree
Hide file tree
Showing 3 changed files with 212 additions and 3 deletions.
123 changes: 121 additions & 2 deletions src/jaxsim/api/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import jaxsim.api as js
import jaxsim.parsers.descriptions
import jaxsim.typing as jtp
from jaxsim.math import Cross
from jaxsim.utils import HashlessObject, JaxsimDataclass, Mutability

from .common import VelRepr
Expand Down Expand Up @@ -871,6 +872,126 @@ def free_floating_mass_matrix(
raise ValueError(data.velocity_representation)


@jax.jit
def free_floating_coriolis_matrix(
model: JaxSimModel, data: js.data.JaxSimModelData
) -> jtp.Matrix:
"""
Compute the free-floating Coriolis matrix of the model.
Args:
model: The model to consider.
data: The data of the considered model.
Returns:
The free-floating Coriolis matrix of the model.
Note:
This function, contrarily to other quantities of the equations of motion,
does not exploit any iterative algorithm. Therefore, the computation of
the Coriolis matrix may be much slower than other quantities.
"""

# We perform all the calculation in body-fixed.
# The Coriolis matrix computed in this representation is converted later
# to the active representation stored in data.
with data.switch_velocity_representation(VelRepr.Body):

B_ν = data.generalized_velocity()

# Doubly-left free-floating Jacobian.
L_J_WL_B = generalized_free_floating_jacobian(model=model, data=data)

# Doubly-left free-floating Jacobian derivative.
L_J̇_WL_B = jax.vmap(
lambda link_index: js.link.jacobian_derivative(
model=model, data=data, link_index=link_index
)
)(js.link.names_to_idxs(model=model, link_names=model.link_names()))

L_M_L = link_spatial_inertia_matrices(model=model)

# Body-fixed link velocities.
# Note: we could have called link.velocity() instead of computing it ourselves,
# but since we need the link Jacobians later, we can save a double calculation.
L_v_WL = jax.vmap(lambda J: J @ B_ν)(L_J_WL_B)

# Compute the contribution of each link to the Coriolis matrix.
def compute_link_contribution(M, v, J, ) -> jtp.Array:

return J.T @ ((Cross.vx_star(v) @ M + M @ Cross.vx(v)) @ J + M @ )

C_B_links = jax.vmap(compute_link_contribution)(
L_M_L,
L_v_WL,
L_J_WL_B,
L_J̇_WL_B,
)

# We need to adjust the Coriolis matrix for fixed-base models.
# In this case, the base link does not contribute to the matrix, and we need to zero
# the off-diagonal terms mapping joint quantities onto the base configuration.
if model.floating_base():
C_B = C_B_links.sum(axis=0)
else:
C_B = C_B_links[1:].sum(axis=0)
C_B = C_B.at[0:6, 6:].set(0.0)
C_B = C_B.at[6:, 0:6].set(0.0)

# Adjust the representation of the Coriolis matrix.
# Refer to https://github.com/traversaro/traversaro-phd-thesis, Section 3.6.
match data.velocity_representation:

case VelRepr.Body:
return C_B

case VelRepr.Inertial:

n = model.dofs()
W_H_B = data.base_transform()
B_X_W = jaxsim.math.Adjoint.from_transform(W_H_B, inverse=True)
B_T_W = jax.scipy.linalg.block_diag(B_X_W, jnp.eye(n))

with data.switch_velocity_representation(VelRepr.Inertial):
W_v_WB = data.base_velocity()
B_Ẋ_W = -B_X_W @ jaxsim.math.Cross.vx(W_v_WB)

B_Ṫ_W = jax.scipy.linalg.block_diag(B_Ẋ_W, jnp.zeros(shape=(n, n)))

with data.switch_velocity_representation(VelRepr.Body):
M = free_floating_mass_matrix(model=model, data=data)

C = B_T_W.T @ (M @ B_Ṫ_W + C_B @ B_T_W)

return C

case VelRepr.Mixed:

n = model.dofs()
BW_H_B = data.base_transform().at[0:3, 3].set(jnp.zeros(3))
B_X_BW = jaxsim.math.Adjoint.from_transform(transform=BW_H_B, inverse=True)
B_T_BW = jax.scipy.linalg.block_diag(B_X_BW, jnp.eye(n))

with data.switch_velocity_representation(VelRepr.Mixed):
BW_v_WB = data.base_velocity()
BW_v_W_BW = BW_v_WB.at[3:6].set(jnp.zeros(3))

BW_v_BW_B = BW_v_WB - BW_v_W_BW
B_Ẋ_BW = -B_X_BW @ jaxsim.math.Cross.vx(BW_v_BW_B)

B_Ṫ_BW = jax.scipy.linalg.block_diag(B_Ẋ_BW, jnp.zeros(shape=(n, n)))

with data.switch_velocity_representation(VelRepr.Body):
M = free_floating_mass_matrix(model=model, data=data)

C = B_T_BW.T @ (M @ B_Ṫ_BW + C_B @ B_T_BW)

return C

case _:
raise ValueError(data.velocity_representation)


@jax.jit
def inverse_dynamics(
model: JaxSimModel,
Expand Down Expand Up @@ -931,8 +1052,6 @@ def to_inertial(C_v̇_WB, W_H_C, C_v_WB, W_v_WC):
expressed in a generic frame C to the inertial-fixed representation W_v̇_WB.
"""

from jaxsim.math import Cross

W_X_C = jaxlie.SE3.from_matrix(W_H_C).adjoint()
C_X_W = jaxlie.SE3.from_matrix(W_H_C).inverse().adjoint()
C_v_WC = C_X_W @ W_v_WC
Expand Down
90 changes: 90 additions & 0 deletions tests/test_api_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import rod

import jaxsim.api as js
import jaxsim.math
from jaxsim import VelRepr

from . import utils_idyntree
Expand Down Expand Up @@ -319,6 +320,95 @@ def test_model_jacobian(
assert pytest.approx(JTf_inertial) == JTf_other, vel_repr.name


def test_coriolis_matrix(
jaxsim_models_types: js.model.JaxSimModel,
velocity_representation: VelRepr,
prng_key: jax.Array,
):

model = jaxsim_models_types

_, subkey = jax.random.split(prng_key, num=2)
data = js.data.random_model_data(
model=model, key=subkey, velocity_representation=velocity_representation
)

# =====
# Tests
# =====

I_ν = data.generalized_velocity()
C = js.model.free_floating_coriolis_matrix(model=model, data=data)

h = js.model.free_floating_bias_forces(model=model, data=data)
g = js.model.free_floating_gravity_forces(model=model, data=data)
= h - g

assert C @ I_ν == pytest.approx()

# Compute the free-floating mass matrix.
# This function will be used to compute the Ṁ with AD.
# Given q, computing Ṁ by AD-ing this function should work out-of-the-box with
# all velocity representations, that are handled internally when computing M.
def M(q) -> jax.Array:

data_ad = js.data.JaxSimModelData.zero(
model=model, velocity_representation=data.velocity_representation
)

data_ad = data_ad.reset_base_position(base_position=q[:3])
data_ad = data_ad.reset_base_quaternion(base_quaternion=q[3:7])
data_ad = data_ad.reset_joint_positions(positions=q[7:])

M = js.model.free_floating_mass_matrix(model=model, data=data_ad)

return M

def compute_q(data: js.data.JaxSimModelData) -> jax.Array:

q = jnp.hstack(
[data.base_position(), data.base_orientation(), data.joint_positions()]
)

return q

def compute_q̇(data: js.data.JaxSimModelData) -> jax.Array:

with data.switch_velocity_representation(VelRepr.Body):
B_ω_WB = data.base_velocity()[3:6]

with data.switch_velocity_representation(VelRepr.Mixed):
W_ṗ_B = data.base_velocity()[0:3]

W_Q̇_B = jaxsim.math.Quaternion.derivative(
quaternion=data.base_orientation(),
omega=B_ω_WB,
omega_in_body_fixed=True,
K=0.0,
).squeeze()

= jnp.hstack([W_ṗ_B, W_Q̇_B, data.joint_velocities()])

return

# Compute q and q̇.
q = compute_q(data)
= compute_q̇(data)

# Compute Ṁ with AD.
dM_dq = jax.jacfwd(M, argnums=0)(q)
= jnp.einsum("ijq,q->ij", dM_dq, )

# We need to zero the blocks projecting joint variables to the base configuration
# for fixed-base models.
if not model.floating_base():
= .at[0:6, 6:].set(0)
= .at[6:, 0:6].set(0)

# Ensure that (Ṁ - 2C) is skew symmetric.
assert - C - C.T == pytest.approx(0)


def test_model_fd_id_consistency(
jaxsim_models_types: js.model.JaxSimModel,
velocity_representation: VelRepr,
Expand Down
2 changes: 1 addition & 1 deletion tests/test_simulations.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,5 +169,5 @@ def test_box_with_zero_gravity(
assert data.base_position() == pytest.approx(
data0.base_position()
+ 0.5 * L_f[:, :3].squeeze() / js.model.total_mass(model=model) * tf**2,
rel=1e-4,
abs=1e-3,
)

0 comments on commit 0422047

Please sign in to comment.