Skip to content

Commit

Permalink
Add new RBDA to compute the doubly-left full Jacobian derivative
Browse files Browse the repository at this point in the history
  • Loading branch information
diegoferigo committed Jun 5, 2024
1 parent 44b9aee commit cbe6c1f
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 2 deletions.
6 changes: 5 additions & 1 deletion src/jaxsim/rbda/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
from .collidable_points import collidable_points_pos_vel
from .crba import crba
from .forward_kinematics import forward_kinematics, forward_kinematics_model
from .jacobian import jacobian, jacobian_full_doubly_left
from .jacobian import (
jacobian,
jacobian_derivative_full_doubly_left,
jacobian_full_doubly_left,
)
from .rnea import rnea
from .soft_contacts import SoftContacts, SoftContactsParams
119 changes: 118 additions & 1 deletion src/jaxsim/rbda/jacobian.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import jaxsim.api as js
import jaxsim.typing as jtp
from jaxsim.math import Adjoint
from jaxsim.math import Adjoint, Cross

from . import utils

Expand Down Expand Up @@ -199,3 +199,120 @@ def compute_full_jacobian(
B_J_full_WL_B = J.squeeze().astype(float)

return B_J_full_WL_B, B_H_L


def jacobian_derivative_full_doubly_left(
model: js.model.JaxSimModel,
*,
joint_positions: jtp.VectorLike,
joint_velocities: jtp.VectorLike,
) -> tuple[jtp.Matrix, jtp.Array]:
r"""
Compute the derivative of the doubly-left full free-floating Jacobian of a model.
The derivative of the full Jacobian is a 6x(6+n) matrix with all the columns filled.
It is useful to run the algorithm once, and then extract the link Jacobian
derivative by filtering the columns of the full Jacobian using the support
parent array :math:`\kappa(i)` of the link.
Args:
model: The model to consider.
joint_positions: The positions of the joints.
joint_velocities: The velocities of the joints.
Returns:
The derivative of the doubly-left full free-floating Jacobian of a model.
"""

_, _, s, _, , _, _, _, _, _ = utils.process_inputs(
model=model, joint_positions=joint_positions, joint_velocities=joint_velocities
)

# Get the parent array λ(i).
# Note: λ(0) must not be used, it's initialized to -1.
λ = model.kin_dyn_parameters.parent_array

# Compute the parent-to-child adjoints and the motion subspaces of the joints.
# These transforms define the relative kinematics of the entire model, including
# the base transform for both floating-base and fixed-base models.
i_X_λi, S = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces(
joint_positions=s, base_transform=jnp.eye(4)
)

# Allocate the buffer of 6D transform base -> link.
B_X_i = jnp.zeros(shape=(model.number_of_links(), 6, 6))
B_X_i = B_X_i.at[0].set(jnp.eye(6))

# Allocate the buffer of 6D transform derivatives base -> link.
B_Ẋ_i = jnp.zeros(shape=(model.number_of_links(), 6, 6))

# Allocate the buffer of the 6D link velocity in body-fixed representation.
B_v_Bi = jnp.zeros(shape=(model.number_of_links(), 6))

# Helper to compute the time derivative of the adjoint matrix.
def A_Ẋ_B(A_X_B: jtp.Matrix, B_v_AB: jtp.Vector) -> jtp.Matrix:
return A_X_B @ Cross.vx(B_v_AB).squeeze()

# ============================================
# Compute doubly-left full Jacobian derivative
# ============================================

# Allocate the Jacobian matrix.
= jnp.zeros(shape=(6, 6 + model.dofs()))

ComputeFullJacobianDerivativeCarry = tuple[
jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax
]

compute_full_jacobian_derivative_carry: ComputeFullJacobianDerivativeCarry = (
B_v_Bi,
B_X_i,
B_Ẋ_i,
,
)

def compute_full_jacobian_derivative(
carry: ComputeFullJacobianDerivativeCarry, i: jtp.Int
) -> tuple[ComputeFullJacobianDerivativeCarry, None]:

ii = i - 1
B_v_Bi, B_X_i, B_Ẋ_i, = carry

# Compute the base (0) to link (i) adjoint matrix.
B_Xi_i = B_X_i[λ[i]] @ Adjoint.inverse(i_X_λi[i])
B_X_i = B_X_i.at[i].set(B_Xi_i)

# Compute the body-fixed velocity of the link.
B_vi_Bi = B_v_Bi[λ[i]] + B_X_i[i] @ S[i].squeeze() * [ii]
B_v_Bi = B_v_Bi.at[i].set(B_vi_Bi)

# Compute the base (0) to link (i) adjoint matrix derivative.
i_Xi_B = Adjoint.inverse(B_Xi_i)
B_Ẋi_i = A_Ẋ_B(A_X_B=B_Xi_i, B_v_AB=i_Xi_B @ B_vi_Bi)
B_Ẋ_i = B_Ẋ_i.at[i].set(B_Ẋi_i)

# Compute the ii-th column of the B_Ṡ_BL(s) matrix.
B_Ṡii_BL = B_Ẋ_i[i] @ S[i]
= .at[0:6, 6 + ii].set(B_Ṡii_BL.squeeze())

return (B_v_Bi, B_X_i, B_Ẋ_i, ), None

(_, B_X_i, B_Ẋ_i, ), _ = (
jax.lax.scan(
f=compute_full_jacobian_derivative,
init=compute_full_jacobian_derivative_carry,
xs=model.kin_dyn_parameters.link_indices[1:],
)
if model.number_of_links() > 1
else [(_, B_X_i, B_Ẋ_i, ), None]
)

# Convert adjoints to SE(3) transforms.
# Returning them here prevents calling FK in case the output representation
# of the Jacobian needs to be changed.
B_H_L = jax.vmap(lambda B_X_L: Adjoint.to_transform(B_X_L))(B_X_i)

# Adjust shape of doubly-left free-floating full Jacobian derivative.
B_J̇_full_WL_B = .squeeze().astype(float)

return B_J̇_full_WL_B, B_H_L

0 comments on commit cbe6c1f

Please sign in to comment.