diff --git a/src/jaxsim/api/link.py b/src/jaxsim/api/link.py index 1d164010b..27c352db9 100644 --- a/src/jaxsim/api/link.py +++ b/src/jaxsim/api/link.py @@ -3,6 +3,7 @@ import jax import jax.numpy as jnp +import jax.scipy.linalg import jaxlie import numpy as np @@ -337,6 +338,187 @@ def velocity( return O_J_WL_I @ I_ν +@functools.partial(jax.jit, static_argnames=["output_vel_repr"]) +def jacobian_derivative( + model: js.model.JaxSimModel, + data: js.data.JaxSimModelData, + *, + link_index: jtp.IntLike, + output_vel_repr: VelRepr | None = None, +) -> jtp.Matrix: + """ + Compute the derivative of the free-floating jacobian of the link. + + Args: + model: The model to consider. + data: The data of the considered model. + link_index: The index of the link. + output_vel_repr: + The output velocity representation of the free-floating jacobian derivative. + + Returns: + The derivative of the 6×(6+n) free-floating jacobian of the link. + + Note: + The input representation of the free-floating jacobian derivative is the active + velocity representation. + """ + + output_vel_repr = ( + output_vel_repr if output_vel_repr is not None else data.velocity_representation + ) + + # Compute the derivative of the doubly-left free-floating full jacobian. + B_J̇_full_WX_B, B_H_L = jaxsim.rbda.jacobian_derivative_full_doubly_left( + model=model, + joint_positions=data.joint_positions(), + joint_velocities=data.joint_velocities(), + ) + + # Compute the actual doubly-left free-floating jacobian derivative of the link + # by zeroing the columns not in the path π_B(L) using the boolean κ(i). + κb = model.kin_dyn_parameters.support_body_array_bool[link_index] + B_J̇_WL_B = jnp.hstack([jnp.ones(5), κb]) * B_J̇_full_WX_B + + # ===================================================== + # Compute quantities to adjust the input representation + # ===================================================== + + In = jnp.eye(model.dofs()) + On = jnp.zeros(shape=(model.dofs(), model.dofs())) + + match data.velocity_representation: + + case VelRepr.Inertial: + + W_H_B = data.base_transform() + B_X_W = jaxsim.math.Adjoint.from_transform(transform=W_H_B, inverse=True) + + with data.switch_velocity_representation(VelRepr.Inertial): + W_v_WB = data.base_velocity() + B_Ẋ_W = -B_X_W @ jaxsim.math.Cross.vx(W_v_WB) + + # Compute the operator to change the representation of ν, and its + # time derivative. + T = jax.scipy.linalg.block_diag(B_X_W, In) + Ṫ = jax.scipy.linalg.block_diag(B_Ẋ_W, On) + + case VelRepr.Body: + + B_X_B = jaxsim.math.Adjoint.from_rotation_and_translation( + translation=jnp.zeros(3), rotation=jnp.eye(3) + ) + + B_Ẋ_B = jnp.zeros(shape=(6, 6)) + + # Compute the operator to change the representation of ν, and its + # time derivative. + T = jax.scipy.linalg.block_diag(B_X_B, In) + Ṫ = jax.scipy.linalg.block_diag(B_Ẋ_B, On) + + case VelRepr.Mixed: + + BW_H_B = data.base_transform().at[0:3, 3].set(jnp.zeros(3)) + B_X_BW = jaxsim.math.Adjoint.from_transform(transform=BW_H_B, inverse=True) + + with data.switch_velocity_representation(VelRepr.Mixed): + BW_v_WB = data.base_velocity() + BW_v_W_BW = BW_v_WB.at[3:6].set(jnp.zeros(3)) + + BW_v_BW_B = BW_v_WB - BW_v_W_BW + B_Ẋ_BW = -B_X_BW @ jaxsim.math.Cross.vx(BW_v_BW_B) + + # Compute the operator to change the representation of ν, and its + # time derivative. + T = jax.scipy.linalg.block_diag(B_X_BW, In) + Ṫ = jax.scipy.linalg.block_diag(B_Ẋ_BW, On) + + case _: + raise ValueError(data.velocity_representation) + + # ====================================================== + # Compute quantities to adjust the output representation + # ====================================================== + + match output_vel_repr: + + case VelRepr.Inertial: + + W_H_B = data.base_transform() + O_X_B = W_X_B = jaxsim.math.Adjoint.from_transform(transform=W_H_B) + + with data.switch_velocity_representation(VelRepr.Body): + B_v_WB = data.base_velocity() + + O_Ẋ_B = W_Ẋ_B = W_X_B @ jaxsim.math.Cross.vx(B_v_WB) + + case VelRepr.Body: + + O_X_B = L_X_B = jaxsim.math.Adjoint.from_transform( + transform=B_H_L[link_index, :, :], inverse=True + ) + + B_X_L = jaxsim.math.Adjoint.inverse(adjoint=L_X_B) + + with data.switch_velocity_representation(VelRepr.Body): + B_v_WB = data.base_velocity() + L_v_WL = js.link.velocity(model=model, data=data, link_index=link_index) + + O_Ẋ_B = L_Ẋ_B = -L_X_B @ jaxsim.math.Cross.vx(B_X_L @ L_v_WL - B_v_WB) + + case VelRepr.Mixed: + + W_H_B = data.base_transform() + W_H_L = W_H_B @ B_H_L[link_index, :, :] + LW_H_L = W_H_L.at[0:3, 3].set(jnp.zeros(3)) + LW_H_B = LW_H_L @ jaxsim.math.Transform.inverse(B_H_L[link_index, :, :]) + + O_X_B = LW_X_B = jaxsim.math.Adjoint.from_transform(transform=LW_H_B) + + B_X_LW = jaxsim.math.Adjoint.inverse(adjoint=LW_X_B) + + with data.switch_velocity_representation(VelRepr.Body): + B_v_WB = data.base_velocity() + + with data.switch_velocity_representation(VelRepr.Mixed): + LW_v_WL = js.link.velocity( + model=model, data=data, link_index=link_index + ) + LW_v_W_LW = LW_v_WL.at[3:6].set(jnp.zeros(3)) + + LW_v_LW_L = LW_v_WL - LW_v_W_LW + LW_v_B_LW = LW_v_WL - LW_X_B @ B_v_WB - LW_v_LW_L + + O_Ẋ_B = LW_Ẋ_B = -LW_X_B @ jaxsim.math.Cross.vx(B_X_LW @ LW_v_B_LW) + + case _: + raise ValueError(output_vel_repr) + + # ============================================================= + # Express the Jacobian derivative in the target representations + # ============================================================= + + # The derivative of the equation to change the input and output representations + # of the Jacobian derivative needs the computation of the plain link Jacobian. + # Compute here the full Jacobian of the model... + B_J_full_WL_B, _ = jaxsim.rbda.jacobian_full_doubly_left( + model=model, + joint_positions=data.joint_positions(), + ) + + # ... and extract the link Jacobian using the boolean support body array. + B_J_WL_B = jnp.hstack([jnp.ones(5), κb]) * B_J_full_WL_B + + # Sum all the components that form the Jacobian derivative in the target + # input/output velocity representations. + O_J̇_WL_I = jnp.zeros(shape=(6, 6 + model.dofs())) + O_J̇_WL_I += O_Ẋ_B @ B_J_WL_B @ T + O_J̇_WL_I += O_X_B @ B_J̇_WL_B @ T + O_J̇_WL_I += O_X_B @ B_J_WL_B @ Ṫ + + return O_J̇_WL_I + + @jax.jit def bias_acceleration( model: js.model.JaxSimModel, diff --git a/src/jaxsim/rbda/__init__.py b/src/jaxsim/rbda/__init__.py index 1d8ace5b8..1e9b5a4ff 100644 --- a/src/jaxsim/rbda/__init__.py +++ b/src/jaxsim/rbda/__init__.py @@ -2,6 +2,10 @@ from .collidable_points import collidable_points_pos_vel from .crba import crba from .forward_kinematics import forward_kinematics, forward_kinematics_model -from .jacobian import jacobian, jacobian_full_doubly_left +from .jacobian import ( + jacobian, + jacobian_derivative_full_doubly_left, + jacobian_full_doubly_left, +) from .rnea import rnea from .soft_contacts import SoftContacts, SoftContactsParams diff --git a/src/jaxsim/rbda/jacobian.py b/src/jaxsim/rbda/jacobian.py index 8ac3a2264..531c921a2 100644 --- a/src/jaxsim/rbda/jacobian.py +++ b/src/jaxsim/rbda/jacobian.py @@ -4,7 +4,7 @@ import jaxsim.api as js import jaxsim.typing as jtp -from jaxsim.math import Adjoint +from jaxsim.math import Adjoint, Cross from . import utils @@ -199,3 +199,120 @@ def compute_full_jacobian( B_J_full_WL_B = J.squeeze().astype(float) return B_J_full_WL_B, B_H_L + + +def jacobian_derivative_full_doubly_left( + model: js.model.JaxSimModel, + *, + joint_positions: jtp.VectorLike, + joint_velocities: jtp.VectorLike, +) -> tuple[jtp.Matrix, jtp.Array]: + r""" + Compute the derivative of the doubly-left full free-floating Jacobian of a model. + + The derivative of the full Jacobian is a 6x(6+n) matrix with all the columns filled. + It is useful to run the algorithm once, and then extract the link Jacobian + derivative by filtering the columns of the full Jacobian using the support + parent array :math:`\kappa(i)` of the link. + + Args: + model: The model to consider. + joint_positions: The positions of the joints. + joint_velocities: The velocities of the joints. + + Returns: + The derivative of the doubly-left full free-floating Jacobian of a model. + """ + + _, _, s, _, ṡ, _, _, _, _, _ = utils.process_inputs( + model=model, joint_positions=joint_positions, joint_velocities=joint_velocities + ) + + # Get the parent array λ(i). + # Note: λ(0) must not be used, it's initialized to -1. + λ = model.kin_dyn_parameters.parent_array + + # Compute the parent-to-child adjoints and the motion subspaces of the joints. + # These transforms define the relative kinematics of the entire model, including + # the base transform for both floating-base and fixed-base models. + i_X_λi, S = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces( + joint_positions=s, base_transform=jnp.eye(4) + ) + + # Allocate the buffer of 6D transform base -> link. + B_X_i = jnp.zeros(shape=(model.number_of_links(), 6, 6)) + B_X_i = B_X_i.at[0].set(jnp.eye(6)) + + # Allocate the buffer of 6D transform derivatives base -> link. + B_Ẋ_i = jnp.zeros(shape=(model.number_of_links(), 6, 6)) + + # Allocate the buffer of the 6D link velocity in body-fixed representation. + B_v_Bi = jnp.zeros(shape=(model.number_of_links(), 6)) + + # Helper to compute the time derivative of the adjoint matrix. + def A_Ẋ_B(A_X_B: jtp.Matrix, B_v_AB: jtp.Vector) -> jtp.Matrix: + return A_X_B @ Cross.vx(B_v_AB).squeeze() + + # ============================================ + # Compute doubly-left full Jacobian derivative + # ============================================ + + # Allocate the Jacobian matrix. + J̇ = jnp.zeros(shape=(6, 6 + model.dofs())) + + ComputeFullJacobianDerivativeCarry = tuple[ + jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax + ] + + compute_full_jacobian_derivative_carry: ComputeFullJacobianDerivativeCarry = ( + B_v_Bi, + B_X_i, + B_Ẋ_i, + J̇, + ) + + def compute_full_jacobian_derivative( + carry: ComputeFullJacobianDerivativeCarry, i: jtp.Int + ) -> tuple[ComputeFullJacobianDerivativeCarry, None]: + + ii = i - 1 + B_v_Bi, B_X_i, B_Ẋ_i, J̇ = carry + + # Compute the base (0) to link (i) adjoint matrix. + B_Xi_i = B_X_i[λ[i]] @ Adjoint.inverse(i_X_λi[i]) + B_X_i = B_X_i.at[i].set(B_Xi_i) + + # Compute the body-fixed velocity of the link. + B_vi_Bi = B_v_Bi[λ[i]] + B_X_i[i] @ S[i].squeeze() * ṡ[ii] + B_v_Bi = B_v_Bi.at[i].set(B_vi_Bi) + + # Compute the base (0) to link (i) adjoint matrix derivative. + i_Xi_B = Adjoint.inverse(B_Xi_i) + B_Ẋi_i = A_Ẋ_B(A_X_B=B_Xi_i, B_v_AB=i_Xi_B @ B_vi_Bi) + B_Ẋ_i = B_Ẋ_i.at[i].set(B_Ẋi_i) + + # Compute the ii-th column of the B_Ṡ_BL(s) matrix. + B_Ṡii_BL = B_Ẋ_i[i] @ S[i] + J̇ = J̇.at[0:6, 6 + ii].set(B_Ṡii_BL.squeeze()) + + return (B_v_Bi, B_X_i, B_Ẋ_i, J̇), None + + (_, B_X_i, B_Ẋ_i, J̇), _ = ( + jax.lax.scan( + f=compute_full_jacobian_derivative, + init=compute_full_jacobian_derivative_carry, + xs=np.arange(start=1, stop=model.number_of_links()), + ) + if model.number_of_links() > 1 + else [(_, B_X_i, B_Ẋ_i, J̇), None] + ) + + # Convert adjoints to SE(3) transforms. + # Returning them here prevents calling FK in case the output representation + # of the Jacobian needs to be changed. + B_H_L = jax.vmap(lambda B_X_L: Adjoint.to_transform(B_X_L))(B_X_i) + + # Adjust shape of doubly-left free-floating full Jacobian derivative. + B_J̇_full_WL_B = J̇.squeeze().astype(float) + + return B_J̇_full_WL_B, B_H_L diff --git a/tests/test_api_link.py b/tests/test_api_link.py index dc619fd80..8816cc46e 100644 --- a/tests/test_api_link.py +++ b/tests/test_api_link.py @@ -211,3 +211,43 @@ def test_link_bias_acceleration( Jν_idt = kin_dyn.frame_bias_acc(frame_name=name) Jν_js = js.link.bias_acceleration(model=model, data=data, link_index=index) assert pytest.approx(Jν_idt) == Jν_js + + +def test_link_jacobian_derivative( + jaxsim_models_types: js.model.JaxSimModel, + velocity_representation: VelRepr, + prng_key: jax.Array, +): + + model = jaxsim_models_types + + _, subkey = jax.random.split(prng_key, num=2) + data = js.data.random_model_data( + model=model, + key=subkey, + velocity_representation=velocity_representation, + ) + + # ===== + # Tests + # ===== + + # Get the generalized velocity. + I_ν = data.generalized_velocity() + + # Compute J̇. + O_J̇_WL_I = jax.vmap( + lambda link_index: js.link.jacobian_derivative( + model=model, data=data, link_index=link_index + ) + )(js.link.names_to_idxs(model=model, link_names=model.link_names())) + + # Compute the product J̇ν. + O_a_bias_WL = jax.vmap( + lambda link_index: js.link.bias_acceleration( + model=model, data=data, link_index=link_index + ) + )(js.link.names_to_idxs(model=model, link_names=model.link_names())) + + # Compare the two computations. + assert jnp.einsum("l6g,g->l6", O_J̇_WL_I, I_ν) == pytest.approx(O_a_bias_WL)