Skip to content

Commit

Permalink
Add test to check automatic differentiation of algorithms
Browse files Browse the repository at this point in the history
  • Loading branch information
diegoferigo committed Mar 8, 2024
1 parent 1546c37 commit 8952c1b
Showing 1 changed file with 393 additions and 0 deletions.
393 changes: 393 additions & 0 deletions tests/test_automatic_differentiation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,393 @@
import jax
import jax.numpy as jnp
from jax.test_util import check_grads

import jaxsim.api as js
from jaxsim import VelRepr


def get_random_data_and_references(
model: js.model.JaxSimModel,
velocity_representation: VelRepr,
key: jax.Array,
) -> tuple[js.data.JaxSimModelData, js.references.JaxSimModelReferences]:

key, subkey = jax.random.split(key, num=2)

data = js.data.random_model_data(
model=model, key=subkey, velocity_representation=velocity_representation
)

key, subkey1, subkey2 = jax.random.split(key, num=3)

references = js.references.JaxSimModelReferences.build(
model=model,
joint_force_references=10 * jax.random.uniform(subkey1, shape=(model.dofs(),)),
link_forces=jax.random.uniform(subkey2, shape=(model.number_of_links(), 6)),
data=data,
velocity_representation=velocity_representation,
)

# Remove the force applied to the base link if the model is fixed-base.
if not model.floating_base():
references = references.apply_link_forces(
forces=jnp.atleast_2d(jnp.zeros(6)),
model=model,
data=data,
link_names=(model.base_link(),),
additive=False,
)

return data, references


def test_ad_aba(
jaxsim_model_ergocub_reduced: js.model.JaxSimModel,
prng_key: jax.Array,
):

model = jaxsim_model_ergocub_reduced

key, subkey = jax.random.split(prng_key, num=2)
data, references = get_random_data_and_references(
model=model, velocity_representation=VelRepr.Inertial, key=key
)

# Perturbation used for computing finite differences.
ε = jnp.finfo(jnp.array(0.0)).resolution ** (1 / 3)

# State in VelRepr.Inertial representation.
s = data.joint_positions()
= data.joint_velocities(model=model)
xfb = data.state.physics_model.xfb()

# Inputs.
f = references.link_forces(model=model)
τ = references.joint_force_references(model=model)

# ====
# Test
# ====

import jaxsim.physics.algos.aba

# Get a closure exposing only the parameters to be differentiated.
aba = lambda xfb, s, , tau, f_ext: jaxsim.physics.algos.aba.aba(
model=model.physics_model, xfb=xfb, q=s, qd=, tau=tau, f_ext=f_ext
)

# Check both first-order and second-order derivatives.
check_grads(
f=aba,
args=(xfb, s, , τ, f),
order=2,
modes=["rev", "fwd"],
eps=ε,
)


def test_ad_rnea(
jaxsim_model_ergocub_reduced: js.model.JaxSimModel,
prng_key: jax.Array,
):

model = jaxsim_model_ergocub_reduced

key, subkey = jax.random.split(prng_key, num=2)
data, references = get_random_data_and_references(
model=model, velocity_representation=VelRepr.Inertial, key=key
)

# Perturbation used for computing finite differences.
ε = jnp.finfo(jnp.array(0.0)).resolution ** (1 / 3)

# State in VelRepr.Inertial representation.
s = data.joint_positions()
= data.joint_velocities(model=model)
xfb = data.state.physics_model.xfb()

# Inputs.
f = references.link_forces(model=model)

# ====
# Test
# ====

import jaxsim.physics.algos.rnea

key, subkey1, subkey2 = jax.random.split(key, num=3)
W_v̇_WB = jax.random.uniform(subkey1, shape=(6,), minval=-1)
= jax.random.uniform(subkey2, shape=(model.dofs(),), minval=-1)

rnea = lambda xfb, s, , , W_v̇_WB, f_ext: jaxsim.physics.algos.rnea.rnea(
model=model.physics_model, xfb=xfb, q=s, qd=, qdd=, a0fb=W_v̇_WB, f_ext=f_ext
)

check_grads(
f=rnea,
args=(xfb, s, , , W_v̇_WB, f),
order=2,
modes=["rev", "fwd"],
eps=ε,
)


def test_ad_crba(
jaxsim_model_ergocub_reduced: js.model.JaxSimModel,
prng_key: jax.Array,
):

model = jaxsim_model_ergocub_reduced

key, subkey = jax.random.split(prng_key, num=2)
data, references = get_random_data_and_references(
model=model, velocity_representation=VelRepr.Inertial, key=key
)

# Perturbation used for computing finite differences.
ε = jnp.finfo(jnp.array(0.0)).resolution ** (1 / 3)

# State in VelRepr.Inertial representation.
s = data.joint_positions()

# ====
# Test
# ====

import jaxsim.physics.algos.crba

# Get a closure exposing only the parameters to be differentiated.
crba = lambda s: jaxsim.physics.algos.crba.crba(model=model.physics_model, q=s)

# Check both first-order and second-order derivatives.
check_grads(
f=crba,
args=(s,),
order=2,
modes=["rev", "fwd"],
eps=ε,
)


def test_ad_fk(
jaxsim_model_ergocub_reduced: js.model.JaxSimModel,
prng_key: jax.Array,
):

model = jaxsim_model_ergocub_reduced

key, subkey = jax.random.split(prng_key, num=2)
data, references = get_random_data_and_references(
model=model, velocity_representation=VelRepr.Inertial, key=key
)

# Perturbation used for computing finite differences.
ε = jnp.finfo(jnp.array(0.0)).resolution ** (1 / 3)

# State in VelRepr.Inertial representation.
s = data.joint_positions()
xfb = data.state.physics_model.xfb()

# ====
# Test
# ====

import jaxsim.physics.algos.forward_kinematics

# Get a closure exposing only the parameters to be differentiated.
fk = (
lambda xfb, s: jaxsim.physics.algos.forward_kinematics.forward_kinematics_model(
model=model.physics_model, xfb=xfb, q=s
)
)

# Check both first-order and second-order derivatives.
check_grads(
f=fk,
args=(xfb, s),
order=2,
modes=["rev", "fwd"],
eps=ε,
)


def test_ad_jacobian(
jaxsim_model_ergocub_reduced: js.model.JaxSimModel,
prng_key: jax.Array,
):

model = jaxsim_model_ergocub_reduced

key, subkey = jax.random.split(prng_key, num=2)
data, references = get_random_data_and_references(
model=model, velocity_representation=VelRepr.Inertial, key=key
)

# Perturbation used for computing finite differences.
ε = jnp.finfo(jnp.array(0.0)).resolution ** (1 / 3)

# State in VelRepr.Inertial representation.
s = data.joint_positions()

# ====
# Test
# ====

import jaxsim.physics.algos.jacobian

# Get the link indices.
link_indices = js.link.names_to_idxs(model=model, link_names=model.link_names())

# Get a closure exposing only the parameters to be differentiated.
# We differentiate the jacobian of the last link, likely among those
# farther from the base.
jacobian = lambda s: jaxsim.physics.algos.jacobian.jacobian(
model=model.physics_model, q=s, body_index=link_indices[-1]
)

# Check both first-order and second-order derivatives.
check_grads(
f=jacobian,
args=(s,),
order=2,
modes=["rev", "fwd"],
eps=ε,
)


def test_ad_soft_contacts(
jaxsim_model_ergocub_reduced: js.model.JaxSimModel,
prng_key: jax.Array,
):

model = jaxsim_model_ergocub_reduced

# Perturbation used for computing finite differences.
ε = jnp.finfo(jnp.array(0.0)).resolution ** (1 / 3)

key, subkey1, subkey2, subkey3 = jax.random.split(prng_key, num=4)
p = jax.random.uniform(subkey1, shape=(3,), minval=-1)
v = jax.random.uniform(subkey2, shape=(3,), minval=-1)
m = jax.random.uniform(subkey3, shape=(3,), minval=-1)

# Get the soft contacts parameters.
parameters = js.contact.estimate_good_soft_contacts_parameters(model=model)

# ====
# Test
# ====

import jaxsim.physics.algos.soft_contacts

# Get a closure exposing only the parameters to be differentiated.
soft_contacts = lambda p, v, m: jaxsim.physics.algos.soft_contacts.SoftContacts(
parameters=parameters
).contact_model(position=p, velocity=v, tangential_deformation=m)

# Check both first-order and second-order derivatives.
check_grads(
f=soft_contacts,
args=(p, v, m),
order=2,
modes=["rev", "fwd"],
eps=ε,
)


def test_ad_integration(
jaxsim_model_ergocub_reduced: js.model.JaxSimModel,
prng_key: jax.Array,
):

model = jaxsim_model_ergocub_reduced

key, subkey = jax.random.split(prng_key, num=2)
data, references = get_random_data_and_references(
model=model, velocity_representation=VelRepr.Inertial, key=key
)

# Perturbation used for computing finite differences.
ε = jnp.finfo(jnp.array(0.0)).resolution ** (1 / 3)

# State in VelRepr.Inertial representation.
s = data.joint_positions()
= data.joint_velocities(model=model)
xfb = data.state.physics_model.xfb()
m = data.state.soft_contacts.tangential_deformation

# Inputs.
f = references.link_forces(model=model)
τ = references.joint_force_references(model=model)

# ====
# Test
# ====

import jaxsim.integrators

# Select a second-order Heun scheme with quaternion integrated on SO(3).
# Note that it's always preferable using the SO(3) versions on AD applications so
# that the gradient of the integrated dynamics always considers unary quaternions.
integrator = jaxsim.integrators.fixed_step.Heun2SO3.build(
dynamics=js.ode.wrap_system_dynamics_for_integration(
model=model,
data=data,
system_dynamics=js.ode.system_dynamics,
),
)

# Initialize the integrator.
t0, dt = 0.0, 0.001
integrator_state = integrator.init(x0=data.state, t0=t0, dt=dt)

# Function exposing only the parameters to be differentiated.
def step(
xfb: jax.typing.ArrayLike,
s: jax.typing.ArrayLike,
: jax.typing.ArrayLike,
m: jax.typing.ArrayLike,
tau: jax.typing.ArrayLike,
f_ext: jax.typing.ArrayLike,
) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array]:

data_x0 = data.replace(
state=js.ode.ODEState.build(
physics_model_state=js.ode.PhysicsModelState.build(
joint_positions=s,
joint_velocities=,
base_position=xfb[4:7],
base_quaternion=xfb[0:4],
base_linear_velocity=xfb[7:10],
base_angular_velocity=xfb[10:13],
),
soft_contacts_state=js.ode.SoftContactsState.build(
tangential_deformation=m
),
),
)

data_xf, _ = js.model.step(
dt=dt,
model=model,
data=data_x0,
integrator=integrator,
integrator_state=integrator_state,
joint_forces=tau,
external_forces=f_ext,
)

s_xf = data_xf.joint_positions()
ṡ_xf = data_xf.joint_velocities()
xfb_xf = data_xf.state.physics_model.xfb()
m_xf = data_xf.state.soft_contacts.tangential_deformation

return xfb_xf, s_xf, ṡ_xf, m_xf

# Check both first-order and second-order derivatives.
check_grads(
f=step,
args=(xfb, s, , m, τ, f),
order=2,
modes=["rev", "fwd"],
eps=ε,
)

0 comments on commit 8952c1b

Please sign in to comment.