Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add algorithm to compute the Jacobian derivative of a link #169

Merged
merged 5 commits into from
Jun 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
182 changes: 182 additions & 0 deletions src/jaxsim/api/link.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import jax
import jax.numpy as jnp
import jax.scipy.linalg
import jaxlie
import numpy as np

Expand Down Expand Up @@ -337,6 +338,187 @@ def velocity(
return O_J_WL_I @ I_ν


@functools.partial(jax.jit, static_argnames=["output_vel_repr"])
def jacobian_derivative(
model: js.model.JaxSimModel,
data: js.data.JaxSimModelData,
*,
link_index: jtp.IntLike,
output_vel_repr: VelRepr | None = None,
) -> jtp.Matrix:
"""
Compute the derivative of the free-floating jacobian of the link.

Args:
model: The model to consider.
data: The data of the considered model.
link_index: The index of the link.
output_vel_repr:
The output velocity representation of the free-floating jacobian derivative.

Returns:
The derivative of the 6×(6+n) free-floating jacobian of the link.

Note:
The input representation of the free-floating jacobian derivative is the active
velocity representation.
"""

output_vel_repr = (
output_vel_repr if output_vel_repr is not None else data.velocity_representation
)

# Compute the derivative of the doubly-left free-floating full jacobian.
B_J̇_full_WX_B, B_H_L = jaxsim.rbda.jacobian_derivative_full_doubly_left(
model=model,
joint_positions=data.joint_positions(),
joint_velocities=data.joint_velocities(),
)

# Compute the actual doubly-left free-floating jacobian derivative of the link
# by zeroing the columns not in the path π_B(L) using the boolean κ(i).
κb = model.kin_dyn_parameters.support_body_array_bool[link_index]
B_J̇_WL_B = jnp.hstack([jnp.ones(5), κb]) * B_J̇_full_WX_B

# =====================================================
# Compute quantities to adjust the input representation
# =====================================================

In = jnp.eye(model.dofs())
On = jnp.zeros(shape=(model.dofs(), model.dofs()))

match data.velocity_representation:

case VelRepr.Inertial:

W_H_B = data.base_transform()
B_X_W = jaxsim.math.Adjoint.from_transform(transform=W_H_B, inverse=True)

with data.switch_velocity_representation(VelRepr.Inertial):
W_v_WB = data.base_velocity()
B_Ẋ_W = -B_X_W @ jaxsim.math.Cross.vx(W_v_WB)

# Compute the operator to change the representation of ν, and its
# time derivative.
T = jax.scipy.linalg.block_diag(B_X_W, In)
Ṫ = jax.scipy.linalg.block_diag(B_Ẋ_W, On)

case VelRepr.Body:

B_X_B = jaxsim.math.Adjoint.from_rotation_and_translation(
translation=jnp.zeros(3), rotation=jnp.eye(3)
)

B_Ẋ_B = jnp.zeros(shape=(6, 6))

# Compute the operator to change the representation of ν, and its
# time derivative.
T = jax.scipy.linalg.block_diag(B_X_B, In)
Ṫ = jax.scipy.linalg.block_diag(B_Ẋ_B, On)

case VelRepr.Mixed:

BW_H_B = data.base_transform().at[0:3, 3].set(jnp.zeros(3))
B_X_BW = jaxsim.math.Adjoint.from_transform(transform=BW_H_B, inverse=True)

with data.switch_velocity_representation(VelRepr.Mixed):
BW_v_WB = data.base_velocity()
BW_v_W_BW = BW_v_WB.at[3:6].set(jnp.zeros(3))

BW_v_BW_B = BW_v_WB - BW_v_W_BW
B_Ẋ_BW = -B_X_BW @ jaxsim.math.Cross.vx(BW_v_BW_B)

# Compute the operator to change the representation of ν, and its
# time derivative.
T = jax.scipy.linalg.block_diag(B_X_BW, In)
Ṫ = jax.scipy.linalg.block_diag(B_Ẋ_BW, On)

case _:
raise ValueError(data.velocity_representation)

# ======================================================
# Compute quantities to adjust the output representation
# ======================================================

match output_vel_repr:

case VelRepr.Inertial:

W_H_B = data.base_transform()
O_X_B = W_X_B = jaxsim.math.Adjoint.from_transform(transform=W_H_B)

with data.switch_velocity_representation(VelRepr.Body):
B_v_WB = data.base_velocity()

O_Ẋ_B = W_Ẋ_B = W_X_B @ jaxsim.math.Cross.vx(B_v_WB)

case VelRepr.Body:

O_X_B = L_X_B = jaxsim.math.Adjoint.from_transform(
transform=B_H_L[link_index, :, :], inverse=True
)

B_X_L = jaxsim.math.Adjoint.inverse(adjoint=L_X_B)

with data.switch_velocity_representation(VelRepr.Body):
B_v_WB = data.base_velocity()
L_v_WL = js.link.velocity(model=model, data=data, link_index=link_index)

O_Ẋ_B = L_Ẋ_B = -L_X_B @ jaxsim.math.Cross.vx(B_X_L @ L_v_WL - B_v_WB)

case VelRepr.Mixed:

W_H_B = data.base_transform()
W_H_L = W_H_B @ B_H_L[link_index, :, :]
LW_H_L = W_H_L.at[0:3, 3].set(jnp.zeros(3))
LW_H_B = LW_H_L @ jaxsim.math.Transform.inverse(B_H_L[link_index, :, :])

O_X_B = LW_X_B = jaxsim.math.Adjoint.from_transform(transform=LW_H_B)

B_X_LW = jaxsim.math.Adjoint.inverse(adjoint=LW_X_B)

with data.switch_velocity_representation(VelRepr.Body):
B_v_WB = data.base_velocity()

with data.switch_velocity_representation(VelRepr.Mixed):
LW_v_WL = js.link.velocity(
model=model, data=data, link_index=link_index
)
LW_v_W_LW = LW_v_WL.at[3:6].set(jnp.zeros(3))

LW_v_LW_L = LW_v_WL - LW_v_W_LW
LW_v_B_LW = LW_v_WL - LW_X_B @ B_v_WB - LW_v_LW_L

O_Ẋ_B = LW_Ẋ_B = -LW_X_B @ jaxsim.math.Cross.vx(B_X_LW @ LW_v_B_LW)

case _:
raise ValueError(output_vel_repr)

# =============================================================
# Express the Jacobian derivative in the target representations
# =============================================================

# The derivative of the equation to change the input and output representations
# of the Jacobian derivative needs the computation of the plain link Jacobian.
# Compute here the full Jacobian of the model...
B_J_full_WL_B, _ = jaxsim.rbda.jacobian_full_doubly_left(
model=model,
joint_positions=data.joint_positions(),
)

# ... and extract the link Jacobian using the boolean support body array.
B_J_WL_B = jnp.hstack([jnp.ones(5), κb]) * B_J_full_WL_B

# Sum all the components that form the Jacobian derivative in the target
# input/output velocity representations.
O_J̇_WL_I = jnp.zeros(shape=(6, 6 + model.dofs()))
O_J̇_WL_I += O_Ẋ_B @ B_J_WL_B @ T
O_J̇_WL_I += O_X_B @ B_J̇_WL_B @ T
O_J̇_WL_I += O_X_B @ B_J_WL_B @ Ṫ

return O_J̇_WL_I


@jax.jit
def bias_acceleration(
model: js.model.JaxSimModel,
Expand Down
6 changes: 5 additions & 1 deletion src/jaxsim/rbda/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
from .collidable_points import collidable_points_pos_vel
from .crba import crba
from .forward_kinematics import forward_kinematics, forward_kinematics_model
from .jacobian import jacobian, jacobian_full_doubly_left
from .jacobian import (
jacobian,
jacobian_derivative_full_doubly_left,
jacobian_full_doubly_left,
)
from .rnea import rnea
from .soft_contacts import SoftContacts, SoftContactsParams
119 changes: 118 additions & 1 deletion src/jaxsim/rbda/jacobian.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import jaxsim.api as js
import jaxsim.typing as jtp
from jaxsim.math import Adjoint
from jaxsim.math import Adjoint, Cross

from . import utils

Expand Down Expand Up @@ -199,3 +199,120 @@ def compute_full_jacobian(
B_J_full_WL_B = J.squeeze().astype(float)

return B_J_full_WL_B, B_H_L


def jacobian_derivative_full_doubly_left(
model: js.model.JaxSimModel,
*,
joint_positions: jtp.VectorLike,
joint_velocities: jtp.VectorLike,
) -> tuple[jtp.Matrix, jtp.Array]:
r"""
Compute the derivative of the doubly-left full free-floating Jacobian of a model.

The derivative of the full Jacobian is a 6x(6+n) matrix with all the columns filled.
It is useful to run the algorithm once, and then extract the link Jacobian
derivative by filtering the columns of the full Jacobian using the support
parent array :math:`\kappa(i)` of the link.

Args:
model: The model to consider.
joint_positions: The positions of the joints.
joint_velocities: The velocities of the joints.

Returns:
The derivative of the doubly-left full free-floating Jacobian of a model.
"""

_, _, s, _, ṡ, _, _, _, _, _ = utils.process_inputs(
model=model, joint_positions=joint_positions, joint_velocities=joint_velocities
)

# Get the parent array λ(i).
# Note: λ(0) must not be used, it's initialized to -1.
λ = model.kin_dyn_parameters.parent_array

# Compute the parent-to-child adjoints and the motion subspaces of the joints.
# These transforms define the relative kinematics of the entire model, including
# the base transform for both floating-base and fixed-base models.
i_X_λi, S = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces(
joint_positions=s, base_transform=jnp.eye(4)
)

# Allocate the buffer of 6D transform base -> link.
B_X_i = jnp.zeros(shape=(model.number_of_links(), 6, 6))
B_X_i = B_X_i.at[0].set(jnp.eye(6))

# Allocate the buffer of 6D transform derivatives base -> link.
B_Ẋ_i = jnp.zeros(shape=(model.number_of_links(), 6, 6))

# Allocate the buffer of the 6D link velocity in body-fixed representation.
B_v_Bi = jnp.zeros(shape=(model.number_of_links(), 6))

# Helper to compute the time derivative of the adjoint matrix.
def A_Ẋ_B(A_X_B: jtp.Matrix, B_v_AB: jtp.Vector) -> jtp.Matrix:
return A_X_B @ Cross.vx(B_v_AB).squeeze()

# ============================================
# Compute doubly-left full Jacobian derivative
# ============================================

# Allocate the Jacobian matrix.
J̇ = jnp.zeros(shape=(6, 6 + model.dofs()))

ComputeFullJacobianDerivativeCarry = tuple[
jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax
]

compute_full_jacobian_derivative_carry: ComputeFullJacobianDerivativeCarry = (
B_v_Bi,
B_X_i,
B_Ẋ_i,
J̇,
)

def compute_full_jacobian_derivative(
carry: ComputeFullJacobianDerivativeCarry, i: jtp.Int
) -> tuple[ComputeFullJacobianDerivativeCarry, None]:

ii = i - 1
B_v_Bi, B_X_i, B_Ẋ_i, J̇ = carry

# Compute the base (0) to link (i) adjoint matrix.
B_Xi_i = B_X_i[λ[i]] @ Adjoint.inverse(i_X_λi[i])
B_X_i = B_X_i.at[i].set(B_Xi_i)

# Compute the body-fixed velocity of the link.
B_vi_Bi = B_v_Bi[λ[i]] + B_X_i[i] @ S[i].squeeze() * ṡ[ii]
B_v_Bi = B_v_Bi.at[i].set(B_vi_Bi)

# Compute the base (0) to link (i) adjoint matrix derivative.
i_Xi_B = Adjoint.inverse(B_Xi_i)
B_Ẋi_i = A_Ẋ_B(A_X_B=B_Xi_i, B_v_AB=i_Xi_B @ B_vi_Bi)
B_Ẋ_i = B_Ẋ_i.at[i].set(B_Ẋi_i)

# Compute the ii-th column of the B_Ṡ_BL(s) matrix.
B_Ṡii_BL = B_Ẋ_i[i] @ S[i]
J̇ = J̇.at[0:6, 6 + ii].set(B_Ṡii_BL.squeeze())

return (B_v_Bi, B_X_i, B_Ẋ_i, J̇), None

(_, B_X_i, B_Ẋ_i, J̇), _ = (
jax.lax.scan(
f=compute_full_jacobian_derivative,
init=compute_full_jacobian_derivative_carry,
xs=np.arange(start=1, stop=model.number_of_links()),
)
if model.number_of_links() > 1
else [(_, B_X_i, B_Ẋ_i, J̇), None]
)

# Convert adjoints to SE(3) transforms.
# Returning them here prevents calling FK in case the output representation
# of the Jacobian needs to be changed.
B_H_L = jax.vmap(lambda B_X_L: Adjoint.to_transform(B_X_L))(B_X_i)

# Adjust shape of doubly-left free-floating full Jacobian derivative.
B_J̇_full_WL_B = J̇.squeeze().astype(float)

return B_J̇_full_WL_B, B_H_L
40 changes: 40 additions & 0 deletions tests/test_api_link.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,3 +211,43 @@ def test_link_bias_acceleration(
Jν_idt = kin_dyn.frame_bias_acc(frame_name=name)
Jν_js = js.link.bias_acceleration(model=model, data=data, link_index=index)
assert pytest.approx(Jν_idt) == Jν_js


def test_link_jacobian_derivative(
jaxsim_models_types: js.model.JaxSimModel,
velocity_representation: VelRepr,
prng_key: jax.Array,
):

model = jaxsim_models_types

_, subkey = jax.random.split(prng_key, num=2)
data = js.data.random_model_data(
model=model,
key=subkey,
velocity_representation=velocity_representation,
)

# =====
# Tests
# =====

# Get the generalized velocity.
I_ν = data.generalized_velocity()

# Compute J̇.
O_J̇_WL_I = jax.vmap(
lambda link_index: js.link.jacobian_derivative(
model=model, data=data, link_index=link_index
)
)(js.link.names_to_idxs(model=model, link_names=model.link_names()))

# Compute the product J̇ν.
O_a_bias_WL = jax.vmap(
lambda link_index: js.link.bias_acceleration(
model=model, data=data, link_index=link_index
)
)(js.link.names_to_idxs(model=model, link_names=model.link_names()))

# Compare the two computations.
assert jnp.einsum("l6g,g->l6", O_J̇_WL_I, I_ν) == pytest.approx(O_a_bias_WL)