Skip to content

Commit

Permalink
Update AD test
Browse files Browse the repository at this point in the history
  • Loading branch information
diegoferigo committed Mar 14, 2024
1 parent e2e6c0e commit 063e866
Showing 1 changed file with 8 additions and 21 deletions.
29 changes: 8 additions & 21 deletions tests/test_automatic_differentiation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from jax.test_util import check_grads

import jaxsim.api as js
import jaxsim.rbda
from jaxsim import VelRepr

# All JaxSim algorithms, excluding the variable-step integrators, should support
Expand Down Expand Up @@ -77,10 +78,8 @@ def test_ad_aba(
# Test
# ====

import jaxsim.physics.algos.aba

# Get a closure exposing only the parameters to be differentiated.
aba = lambda xfb, s, , tau, f_ext: jaxsim.physics.algos.aba.aba(
aba = lambda xfb, s, , tau, f_ext: jaxsim.rbda.aba(
model=model.physics_model, xfb=xfb, q=s, qd=, tau=tau, f_ext=f_ext
)

Expand Down Expand Up @@ -121,14 +120,12 @@ def test_ad_rnea(
# Test
# ====

import jaxsim.physics.algos.rnea

key, subkey1, subkey2 = jax.random.split(key, num=3)
W_v̇_WB = jax.random.uniform(subkey1, shape=(6,), minval=-1)
= jax.random.uniform(subkey2, shape=(model.dofs(),), minval=-1)

# Get a closure exposing only the parameters to be differentiated.
rnea = lambda xfb, s, , , W_v̇_WB, f_ext: jaxsim.physics.algos.rnea.rnea(
rnea = lambda xfb, s, , , W_v̇_WB, f_ext: jaxsim.rbda.rnea(
model=model.physics_model, xfb=xfb, q=s, qd=, qdd=, a0fb=W_v̇_WB, f_ext=f_ext
)

Expand Down Expand Up @@ -164,10 +161,8 @@ def test_ad_crba(
# Test
# ====

import jaxsim.physics.algos.crba

# Get a closure exposing only the parameters to be differentiated.
crba = lambda s: jaxsim.physics.algos.crba.crba(model=model.physics_model, q=s)
crba = lambda s: jaxsim.rbda.crba(model=model.physics_model, q=s)

# Check derivatives against finite differences.
check_grads(
Expand Down Expand Up @@ -202,13 +197,9 @@ def test_ad_fk(
# Test
# ====

import jaxsim.physics.algos.forward_kinematics

# Get a closure exposing only the parameters to be differentiated.
fk = (
lambda xfb, s: jaxsim.physics.algos.forward_kinematics.forward_kinematics_model(
model=model.physics_model, xfb=xfb, q=s
)
fk = lambda xfb, s: jaxsim.rbda.forward_kinematics_model(
model=model.physics_model, xfb=xfb, q=s
)

# Check derivatives against finite differences.
Expand Down Expand Up @@ -243,15 +234,13 @@ def test_ad_jacobian(
# Test
# ====

import jaxsim.physics.algos.jacobian

# Get the link indices.
link_indices = js.link.names_to_idxs(model=model, link_names=model.link_names())

# Get a closure exposing only the parameters to be differentiated.
# We differentiate the jacobian of the last link, likely among those
# farther from the base.
jacobian = lambda s: jaxsim.physics.algos.jacobian.jacobian(
jacobian = lambda s: jaxsim.rbda.jacobian(
model=model.physics_model, q=s, body_index=link_indices[-1]
)

Expand Down Expand Up @@ -287,10 +276,8 @@ def test_ad_soft_contacts(
# Test
# ====

import jaxsim.physics.algos.soft_contacts

# Get a closure exposing only the parameters to be differentiated.
soft_contacts = lambda p, v, m: jaxsim.physics.algos.soft_contacts.SoftContacts(
soft_contacts = lambda p, v, m: jaxsim.rbda.SoftContacts(
parameters=parameters
).contact_model(position=p, velocity=v, tangential_deformation=m)

Expand Down

0 comments on commit 063e866

Please sign in to comment.