Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update tests to use the functional API #102

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/jaxsim/api/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def inertial_to_other_representation(
transform: jtp.Matrix,
is_force: bool = False,
) -> jtp.Array:
"""
r"""
Convert a 6D quantity from inertial-fixed to another representation.

Args:
Expand Down Expand Up @@ -144,7 +144,7 @@ def other_representation_to_inertial(
transform: jtp.Matrix,
is_force: bool = False,
) -> jtp.Array:
"""
r"""
Convert a 6D quantity from another representation to inertial-fixed.

Args:
Expand Down
4 changes: 2 additions & 2 deletions src/jaxsim/api/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,7 @@ def base_velocity(self) -> jtp.Vector:

@jax.jit
def generalized_position(self) -> tuple[jtp.Matrix, jtp.Vector]:
"""
r"""
Get the generalized position
:math:`\mathbf{q} = ({}^W \mathbf{H}_B, \mathbf{s}) \in \text{SO}(3) \times \mathbb{R}^n`.

Expand All @@ -408,7 +408,7 @@ def generalized_position(self) -> tuple[jtp.Matrix, jtp.Vector]:

@jax.jit
def generalized_velocity(self) -> jtp.Vector:
"""
r"""
Get the generalized velocity
:math:`\boldsymbol{\nu} = (\boldsymbol{v}_{W,B};\, \boldsymbol{\omega}_{W,B};\, \mathbf{s}) \in \mathbb{R}^{6+n}`

Expand Down
4 changes: 2 additions & 2 deletions src/jaxsim/api/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -865,7 +865,7 @@ def to_inertial(C_v̇_WB, W_H_C, C_v_WB, W_vl_WC):
def free_floating_gravity_forces(
model: JaxSimModel, data: js.data.JaxSimModelData
) -> jtp.Vector:
"""
r"""
Compute the free-floating gravity forces :math:`g(\mathbf{q})` of the model.

Args:
Expand Down Expand Up @@ -914,7 +914,7 @@ def free_floating_gravity_forces(
def free_floating_bias_forces(
model: JaxSimModel, data: js.data.JaxSimModelData
) -> jtp.Vector:
"""
r"""
Compute the free-floating bias forces :math:`h(\mathbf{q}, \boldsymbol{\nu})`
of the model.

Expand Down
2 changes: 1 addition & 1 deletion src/jaxsim/integrators/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ def integrate_rk_stage(
def post_process_state(
cls, x0: State, t0: Time, xf: NextState, dt: TimeStep
) -> NextState:
"""
r"""
Post-process the integrated state at :math:`t_f = t_0 + \Delta t`.

Args:
Expand Down
42 changes: 19 additions & 23 deletions tests/test_ad_physics.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
from jax.test_util import check_grads
from pytest import param as p

from jaxsim.high_level.common import VelRepr
from jaxsim.high_level.model import Model
import jaxsim.api as js
from jaxsim import VelRepr

from . import utils_models, utils_rng
from .utils_models import Robot
Expand All @@ -31,43 +31,37 @@ def test_ad_physics(robot: utils_models.Robot, vel_repr: VelRepr) -> None:
# Get the URDF of the robot
urdf_file_path = utils_models.ModelFactory.get_model_description(robot=robot)

# Build the high-level model
model = Model.build_from_model_description(
# Build the model
model = js.model.JaxSimModel.build_from_model_description(
model_description=urdf_file_path,
vel_repr=vel_repr,
gravity=gravity,
is_urdf=True,
).mutable(mutable=True, validate=True)
gravity=gravity,
)

random_state = utils_rng.random_model_state(model=model)

# Initialize the model with a random state
model.data.model_state = utils_rng.random_physics_model_state(
physics_model=model.physics_model
data = js.data.JaxSimModelData.build(
model=model, velocity_representation=vel_repr, **random_state
)

# Initialize the model with a random input
model.data.model_input = utils_rng.random_physics_model_input(
physics_model=model.physics_model
)
tau, f_ext = utils_rng.random_model_input(model=model)

# ========================
# Extract state and inputs
# ========================

# Extract the physics model used in the low-level physics algorithms
physics_model = model.physics_model

# State
s = model.joint_positions()
ṡ = model.joint_velocities()
xfb = model.data.model_state.xfb()

# Inputs
f_ext = model.external_forces()
tau = model.joint_generalized_forces_targets()
s = data.joint_positions(model=model)
ṡ = data.joint_velocities(model=model)
xfb = data.state.physics_model.xfb()

# Perturbation used for computing finite differences
ε = jnp.finfo(jnp.array(0.0)).resolution ** (1 / 3)

physics_model = model.physics_model

# =====================================================
# Check first-order and second-order derivatives of ABA
# =====================================================
Expand Down Expand Up @@ -149,7 +143,9 @@ def test_ad_physics(robot: utils_models.Robot, vel_repr: VelRepr) -> None:

import jaxsim.physics.algos.jacobian

link_indices = [l.index() for l in model.links()]
link_indices = [
js.link.name_to_idx(model=model, link_name=link) for link in model.link_names()
]

jacobian = lambda s: jaxsim.physics.algos.jacobian.jacobian(
model=physics_model, q=s, body_index=link_indices[-1]
Expand Down
69 changes: 38 additions & 31 deletions tests/test_eom.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
import pytest
from pytest import param as p

from jaxsim.high_level.common import VelRepr
from jaxsim.high_level.model import Model
import jaxsim.api as js
from jaxsim import VelRepr

from . import utils_idyntree, utils_models, utils_rng
from .utils_models import Robot
Expand Down Expand Up @@ -38,26 +38,27 @@ def test_eom(robot: utils_models.Robot, vel_repr: VelRepr) -> None:
# Get the URDF of the robot
urdf_file_path = utils_models.ModelFactory.get_model_description(robot=robot)

# Build the high-level model
model_jaxsim = Model.build_from_model_description(
# Build the model
model_jaxsim = js.model.JaxSimModel.build_from_model_description(
model_description=urdf_file_path,
vel_repr=vel_repr,
gravity=gravity,
is_urdf=True,
).mutable(mutable=True, validate=True)
gravity=gravity,
)

random_state = utils_rng.random_model_state(model=model_jaxsim)

# Initialize the model with a random state
model_jaxsim.data.model_state = utils_rng.random_physics_model_state(
physics_model=model_jaxsim.physics_model
data = js.data.JaxSimModelData.build(
model=model_jaxsim, velocity_representation=vel_repr, **random_state
)

# Initialize the model with a random input
model_jaxsim.data.model_input = utils_rng.random_physics_model_input(
physics_model=model_jaxsim.physics_model
)
tau, f_ext = utils_rng.random_model_input(model=model_jaxsim)

# Get the joint torques
tau = model_jaxsim.joint_generalized_forces_targets()
link_indices = [
js.link.name_to_idx(model=model_jaxsim, link_name=link)
for link in model_jaxsim.link_names()
]

# ==========================
# Ground truth with iDynTree
Expand All @@ -71,19 +72,19 @@ def test_eom(robot: utils_models.Robot, vel_repr: VelRepr) -> None:
)

kin_dyn.set_robot_state(
joint_positions=np.array(model_jaxsim.joint_positions()),
joint_velocities=np.array(model_jaxsim.joint_velocities()),
base_transform=np.array(model_jaxsim.base_transform()),
base_velocity=np.array(model_jaxsim.base_velocity()),
joint_positions=np.array(data.joint_positions()),
joint_velocities=np.array(data.joint_velocities()),
base_transform=np.array(data.base_transform()),
base_velocity=np.array(data.base_velocity()),
)

assert kin_dyn.joint_names() == list(model_jaxsim.joint_names())
assert kin_dyn.gravity == pytest.approx(model_jaxsim.physics_model.gravity[0:3])
assert kin_dyn.joint_positions() == pytest.approx(model_jaxsim.joint_positions())
assert kin_dyn.joint_velocities() == pytest.approx(model_jaxsim.joint_velocities())
assert kin_dyn.base_velocity() == pytest.approx(model_jaxsim.base_velocity())
assert kin_dyn.frame_transform(model_jaxsim.base_frame()) == pytest.approx(
model_jaxsim.base_transform()
assert kin_dyn.gravity == pytest.approx(data.gravity[0:3])
assert kin_dyn.joint_positions() == pytest.approx(data.joint_positions())
assert kin_dyn.joint_velocities() == pytest.approx(data.joint_velocities())
assert kin_dyn.base_velocity() == pytest.approx(data.base_velocity())
assert kin_dyn.frame_transform(model_jaxsim.base_link()) == pytest.approx(
data.base_transform()
)

M_idt = kin_dyn.mass_matrix()
Expand All @@ -101,10 +102,15 @@ def test_eom(robot: utils_models.Robot, vel_repr: VelRepr) -> None:
# Test individual terms of the EoM
# ================================

M_jaxsim = model_jaxsim.free_floating_mass_matrix()
g_jaxsim = model_jaxsim.free_floating_gravity_forces()
J_jaxsim = jnp.vstack([link.jacobian() for link in model_jaxsim.links()])
h_jaxsim = model_jaxsim.free_floating_bias_forces()
M_jaxsim = js.model.free_floating_mass_matrix(model=model_jaxsim, data=data)
g_jaxsim = js.model.free_floating_gravity_forces(model=model_jaxsim, data=data)
J_jaxsim = jnp.vstack(
[
js.link.jacobian(model=model_jaxsim, data=data, link_index=idx)
for idx in link_indices
]
)
h_jaxsim = js.model.free_floating_bias_forces(model=model_jaxsim, data=data)

# Support both fixed-base and floating-base models by slicing the first six rows
sl = np.s_[0:] if model_jaxsim.floating_base() else np.s_[6:]
Expand All @@ -118,9 +124,10 @@ def test_eom(robot: utils_models.Robot, vel_repr: VelRepr) -> None:
# Test the forward dynamics computed with CRB
# ===========================================

J_ff = model_jaxsim.generalized_free_floating_jacobian()
f_ext = model_jaxsim.external_forces().flatten()
ν̇ = np.hstack(model_jaxsim.forward_dynamics_crb(tau=tau))
J_ff = js.model.generalized_free_floating_jacobian(model=model_jaxsim, data=data)
ν̇ = np.hstack(
js.model.forward_dynamics_crb(model=model_jaxsim, data=data, joint_forces=tau)
)
S = np.block(
[np.zeros(shape=(model_jaxsim.dofs(), 6)), np.eye(model_jaxsim.dofs())]
).T
Expand Down
38 changes: 21 additions & 17 deletions tests/test_forward_dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
import pytest
from pytest import param as p

from jaxsim.high_level.common import VelRepr
from jaxsim.high_level.model import Model
import jaxsim.api as js
from jaxsim import VelRepr

from . import utils_models, utils_rng
from .utils_models import Robot
Expand Down Expand Up @@ -37,35 +37,39 @@ def test_aba(robot: utils_models.Robot, vel_repr: VelRepr) -> None:
# Get the URDF of the robot
urdf_file_path = utils_models.ModelFactory.get_model_description(robot=robot)

# Build the high-level model
model = Model.build_from_model_description(
# Build the model
model = js.model.JaxSimModel.build_from_model_description(
model_description=urdf_file_path,
vel_repr=vel_repr,
gravity=gravity,
is_urdf=True,
).mutable(mutable=True, validate=True)
gravity=gravity,
)

random_state = utils_rng.random_model_state(model=model)

# Initialize the model with a random state
model.data.model_state = utils_rng.random_physics_model_state(
physics_model=model.physics_model
data = js.data.JaxSimModelData.build(
model=model, velocity_representation=vel_repr, **random_state
)

# Initialize the model with a random input
model.data.model_input = utils_rng.random_physics_model_input(
physics_model=model.physics_model
)

# Get the joint torques
tau = model.joint_generalized_forces_targets()
tau, _ = utils_rng.random_model_input(model=model)

# Compute model acceleration with ABA
v̇_WB_aba, s̈_aba = model.forward_dynamics_aba(tau=tau)
v̇_WB_aba, s̈_aba = js.model.forward_dynamics_aba(
model=model,
data=data,
joint_forces=tau,
)

# ==============================================
# Compute forward dynamics with dedicated method
# ==============================================

v̇_WB, s̈ = model.forward_dynamics_crb(tau=tau)
v̇_WB, s̈ = js.model.forward_dynamics_crb(
model=model,
data=data,
joint_forces=tau,
)

assert s̈.squeeze() == pytest.approx(s̈_aba.squeeze(), abs=0.5)
assert v̇_WB.squeeze() == pytest.approx(v̇_WB_aba.squeeze(), abs=0.2)
Loading
Loading