Skip to content

Commit

Permalink
Test J̇ computation against AD
Browse files Browse the repository at this point in the history
  • Loading branch information
diegoferigo committed Jun 10, 2024
1 parent cad214b commit 4e7f13b
Showing 1 changed file with 64 additions and 1 deletion.
65 changes: 64 additions & 1 deletion tests/test_api_link.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pytest

import jaxsim.api as js
import jaxsim.math
from jaxsim import VelRepr

from . import utils_idyntree
Expand Down Expand Up @@ -250,4 +251,66 @@ def test_link_jacobian_derivative(
)(js.link.names_to_idxs(model=model, link_names=model.link_names()))

# Compare the two computations.
assert jnp.einsum("l6g,g->l6", O_J̇_WL_I, I_ν) == pytest.approx(O_a_bias_WL)
assert jnp.einsum("l6g,g->l6", O_J̇_WL_I, I_ν) == pytest.approx(
O_a_bias_WL, abs=1e-9
)

# Compute the plain Jacobian.
# This function will be used to compute the Jacobian derivative with AD.
# Given q, computing J̇ by AD-ing this function should work out-of-the-box with
# all velocity representations, that are handled internally when computing J.
def J(q) -> jax.Array:

data_ad = js.data.JaxSimModelData.zero(
model=model, velocity_representation=data.velocity_representation
)

data_ad = data_ad.reset_base_position(base_position=q[:3])
data_ad = data_ad.reset_base_quaternion(base_quaternion=q[3:7])
data_ad = data_ad.reset_joint_positions(positions=q[7:])

O_J_WL_I = js.model.generalized_free_floating_jacobian(
model=model, data=data_ad
)

return O_J_WL_I

def compute_q(data: js.data.JaxSimModelData) -> jax.Array:

q = jnp.hstack(
[data.base_position(), data.base_orientation(), data.joint_positions()]
)

return q

def compute_q̇(data: js.data.JaxSimModelData) -> jax.Array:

with data.switch_velocity_representation(VelRepr.Body):
B_ω_WB = data.base_velocity()[3:6]

with data.switch_velocity_representation(VelRepr.Mixed):
W_ṗ_B = data.base_velocity()[0:3]

W_Q̇_B = jaxsim.math.Quaternion.derivative(
quaternion=data.base_orientation(),
omega=B_ω_WB,
omega_in_body_fixed=True,
K=0.0,
).squeeze()

= jnp.hstack([W_ṗ_B, W_Q̇_B, data.joint_velocities()])

return

# Compute q and q̇.
q = compute_q(data)
= compute_q̇(data)

# Compute dJ/dt with AD.
dJ_dq = jax.jacfwd(J, argnums=0)(q)
O_J̇_ad_WL_I = jnp.einsum("ijkq,q->ijk", dJ_dq, )

assert O_J̇_ad_WL_I == pytest.approx(O_J̇_WL_I)
assert jnp.einsum("l6g,g->l6", O_J̇_ad_WL_I, I_ν) == pytest.approx(
jnp.einsum("l6g,g->l6", O_J̇_WL_I, I_ν)
)

0 comments on commit 4e7f13b

Please sign in to comment.