Skip to content

Commit

Permalink
Update tests with the jit-compatible exceptions
Browse files Browse the repository at this point in the history
  • Loading branch information
diegoferigo committed Jun 18, 2024
1 parent 0d25785 commit 31a277e
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 3 deletions.
28 changes: 25 additions & 3 deletions tests/test_api_frame.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import jax
import jax.numpy as jnp
import jaxlib.xla_extension
import pytest

import jaxsim.api as js
Expand All @@ -24,7 +25,7 @@ def test_frame_index(jaxsim_models_types: js.model.JaxSimModel):
assert js.frame.name_to_idx(model=model, frame_name=frame_name) == frame_index
assert js.frame.idx_to_name(model=model, frame_index=frame_index) == frame_name
assert (
js.frame.idx_of_parent_link(model=model, frame_idx=frame_index)
js.frame.idx_of_parent_link(model=model, frame_index=frame_index)
< model.number_of_links()
)

Expand All @@ -44,6 +45,27 @@ def test_frame_index(jaxsim_models_types: js.model.JaxSimModel):
== model.frame_names()
)

with pytest.raises(ValueError):
_ = js.frame.name_to_idx(model=model, frame_name="non_existent_frame")

with pytest.raises(jaxlib.xla_extension.XlaRuntimeError):
_ = js.frame.idx_to_name(model=model, frame_index=-1)

with pytest.raises(jaxlib.xla_extension.XlaRuntimeError):
_ = js.frame.idx_to_name(model=model, frame_index=n_l - 1)

with pytest.raises(jaxlib.xla_extension.XlaRuntimeError):
_ = js.frame.idx_to_name(model=model, frame_index=n_l + n_f)

with pytest.raises(jaxlib.xla_extension.XlaRuntimeError):
_ = js.frame.idx_of_parent_link(model=model, frame_index=-1)

with pytest.raises(jaxlib.xla_extension.XlaRuntimeError):
_ = js.frame.idx_of_parent_link(model=model, frame_index=n_l - 1)

with pytest.raises(jaxlib.xla_extension.XlaRuntimeError):
_ = js.frame.idx_of_parent_link(model=model, frame_index=n_l + n_f)


def test_frame_transforms(
jaxsim_models_types: js.model.JaxSimModel,
Expand Down Expand Up @@ -141,8 +163,8 @@ def test_frame_jacobians(

assert len(frame_indices) == len(frame_names)

for frame_name, frame_idx in zip(frame_names, frame_indices):
for frame_name, frame_index in zip(frame_names, frame_indices):

J_WL_js = js.frame.jacobian(model=model, data=data, frame_index=frame_idx)
J_WL_js = js.frame.jacobian(model=model, data=data, frame_index=frame_index)
J_WL_idt = kin_dyn.jacobian_frame(frame_name=frame_name)
assert J_WL_js == pytest.approx(J_WL_idt, abs=1e-9)
10 changes: 10 additions & 0 deletions tests/test_api_joint.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import jax.numpy as jnp
import jaxlib.xla_extension
import pytest

import jaxsim.api as js
Expand Down Expand Up @@ -33,3 +34,12 @@ def test_joint_index(
)
== model.joint_names()
)

with pytest.raises(ValueError):
_ = js.joint.name_to_idx(model=model, joint_name="non_existent_joint")

with pytest.raises(jaxlib.xla_extension.XlaRuntimeError):
_ = js.joint.idx_to_name(model=model, joint_index=-1)

with pytest.raises(jaxlib.xla_extension.XlaRuntimeError):
_ = js.joint.idx_to_name(model=model, joint_index=model.number_of_joints())
10 changes: 10 additions & 0 deletions tests/test_api_link.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import jax
import jax.numpy as jnp
import jaxlib.xla_extension
import pytest

import jaxsim.api as js
Expand Down Expand Up @@ -39,6 +40,15 @@ def test_link_index(
== model.link_names()
)

with pytest.raises(ValueError):
_ = js.link.name_to_idx(model=model, link_name="non_existent_link")

with pytest.raises(jaxlib.xla_extension.XlaRuntimeError):
_ = js.link.idx_to_name(model=model, link_index=-1)

with pytest.raises(jaxlib.xla_extension.XlaRuntimeError):
_ = js.link.idx_to_name(model=model, link_index=model.number_of_links())


def test_link_inertial_properties(
jaxsim_models_types: js.model.JaxSimModel,
Expand Down

0 comments on commit 31a277e

Please sign in to comment.