Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve efficiency of the test suite #166

Merged
merged 6 commits into from
Jun 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,10 @@ profile = "black"
[tool.pytest.ini_options]
addopts = "-rsxX -v --strict-markers"
minversion = "6.0"
preview = true
testpaths = [
"tests",
]

target-version = "py311"

[tool.ruff]
exclude = [
".git",
Expand Down
2 changes: 1 addition & 1 deletion tests/test_api_com.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def test_com_properties(

model = jaxsim_models_types

key, subkey = jax.random.split(prng_key, num=2)
_, subkey = jax.random.split(prng_key, num=2)
data = js.data.random_model_data(
model=model, key=subkey, velocity_representation=velocity_representation
)
Expand Down
6 changes: 3 additions & 3 deletions tests/test_api_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def test_data_joint_indexing(

model = jaxsim_models_types

key, subkey = jax.random.split(prng_key, num=2)
_, subkey = jax.random.split(prng_key, num=2)
data = js.data.random_model_data(
model=model, key=subkey, velocity_representation=velocity_representation
)
Expand Down Expand Up @@ -56,7 +56,7 @@ def test_data_switch_velocity_representation(

model = jaxsim_models_types

key, subkey = jax.random.split(prng_key, num=2)
_, subkey = jax.random.split(prng_key, num=2)
data = js.data.random_model_data(
model=model, key=subkey, velocity_representation=VelRepr.Inertial
)
Expand Down Expand Up @@ -98,7 +98,7 @@ def test_data_change_velocity_representation(

model = jaxsim_models_types

key, subkey = jax.random.split(prng_key, num=2)
_, subkey = jax.random.split(prng_key, num=2)
data = js.data.random_model_data(
model=model, key=subkey, velocity_representation=VelRepr.Inertial
)
Expand Down
8 changes: 4 additions & 4 deletions tests/test_api_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def test_frame_transforms(

model = jaxsim_models_types

key, subkey = jax.random.split(prng_key, num=2)
_, subkey = jax.random.split(prng_key, num=2)
data = js.data.random_model_data(
model=model, key=subkey, velocity_representation=VelRepr.Inertial
)
Expand Down Expand Up @@ -82,7 +82,7 @@ def test_frame_transforms(

assert len(frame_indices) == len(frame_names)

for frame_name, frame_idx in zip(frame_names, frame_indices):
for frame_name in frame_names:

W_H_F_js = js.frame.transform(
model=model,
Expand All @@ -101,7 +101,7 @@ def test_frame_jacobians(

model = jaxsim_models_types

key, subkey = jax.random.split(prng_key, num=2)
_, subkey = jax.random.split(prng_key, num=2)
data = js.data.random_model_data(
model=model, key=subkey, velocity_representation=velocity_representation
)
Expand All @@ -119,7 +119,7 @@ def test_frame_jacobians(

# Lower the number of frames for models with many frames.
if model.name().lower() == "ergocub":
assert any(["sole" in name for name in frame_names])
assert any("sole" in name for name in frame_names)
frame_names = [name for name in frame_names if "sole" in name]

# Get indices of frames.
Expand Down
8 changes: 4 additions & 4 deletions tests/test_api_link.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def test_link_inertial_properties(

model = jaxsim_models_types

key, subkey = jax.random.split(prng_key, num=2)
_, subkey = jax.random.split(prng_key, num=2)
data = js.data.random_model_data(
model=model,
key=subkey,
Expand Down Expand Up @@ -84,7 +84,7 @@ def test_link_transforms(

model = jaxsim_models_types

key, subkey = jax.random.split(prng_key, num=2)
_, subkey = jax.random.split(prng_key, num=2)
data = js.data.random_model_data(
model=model,
key=subkey,
Expand Down Expand Up @@ -122,7 +122,7 @@ def test_link_jacobians(

model = jaxsim_models_types

key, subkey = jax.random.split(prng_key, num=2)
_, subkey = jax.random.split(prng_key, num=2)
data = js.data.random_model_data(
model=model,
key=subkey,
Expand Down Expand Up @@ -168,7 +168,7 @@ def test_link_bias_acceleration(

model = jaxsim_models_types

key, subkey = jax.random.split(prng_key, num=2)
_, subkey = jax.random.split(prng_key, num=2)
data = js.data.random_model_data(
model=model,
key=subkey,
Expand Down
17 changes: 8 additions & 9 deletions tests/test_api_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def test_model_creation_and_reduction(

model_full = jaxsim_model_ergocub

key, subkey = jax.random.split(prng_key, num=2)
_, subkey = jax.random.split(prng_key, num=2)
data_full = js.data.random_model_data(
model=model_full,
key=subkey,
Expand Down Expand Up @@ -77,15 +77,14 @@ def test_model_creation_and_reduction(
model_reduced = js.model.reduce(
model=model_full,
considered_joints=reduced_joints,
locked_joint_positions={
name: pos
for name, pos in zip(
locked_joint_positions=dict(
zip(
model_full.joint_names(),
data_full.joint_positions(
model=model_full, joint_names=model_full.joint_names()
).tolist(),
)
},
),
)

# Check DoFs.
Expand Down Expand Up @@ -156,7 +155,7 @@ def test_model_properties(

model = jaxsim_models_types

key, subkey = jax.random.split(prng_key, num=2)
_, subkey = jax.random.split(prng_key, num=2)
data = js.data.random_model_data(
model=model, key=subkey, velocity_representation=velocity_representation
)
Expand Down Expand Up @@ -202,7 +201,7 @@ def test_model_rbda(

model = jaxsim_models_types

key, subkey = jax.random.split(prng_key, num=2)
_, subkey = jax.random.split(prng_key, num=2)
data = js.data.random_model_data(
model=model, key=subkey, velocity_representation=velocity_representation
)
Expand Down Expand Up @@ -265,7 +264,7 @@ def test_model_jacobian(
# =====

# Create random references (joint torques and link forces)
key, subkey1, subkey2 = jax.random.split(key, num=3)
_, subkey1, subkey2 = jax.random.split(key, num=3)
references = js.references.JaxSimModelReferences.build(
model=model,
joint_force_references=10 * jax.random.uniform(subkey1, shape=(model.dofs(),)),
Expand Down Expand Up @@ -335,7 +334,7 @@ def test_model_fd_id_consistency(
# =====

# Create random references (joint torques and link forces)
key, subkey1, subkey2 = jax.random.split(key, num=3)
_, subkey1, subkey2 = jax.random.split(key, num=3)
references = js.references.JaxSimModelReferences.build(
model=model,
joint_force_references=10 * jax.random.uniform(subkey1, shape=(model.dofs(),)),
Expand Down
16 changes: 8 additions & 8 deletions tests/test_automatic_differentiation.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def get_random_data_and_references(
model=model, key=subkey, velocity_representation=velocity_representation
)

key, subkey1, subkey2 = jax.random.split(key, num=3)
_, subkey1, subkey2 = jax.random.split(key, num=3)

references = js.references.JaxSimModelReferences.build(
model=model,
Expand Down Expand Up @@ -65,7 +65,7 @@ def test_ad_aba(

model = jaxsim_models_types

key, subkey = jax.random.split(prng_key, num=2)
_, subkey = jax.random.split(prng_key, num=2)
data, references = get_random_data_and_references(
model=model, velocity_representation=VelRepr.Inertial, key=subkey
)
Expand Down Expand Up @@ -141,7 +141,7 @@ def test_ad_rnea(
# Test
# ====

key, subkey1, subkey2 = jax.random.split(key, num=3)
_, subkey1, subkey2 = jax.random.split(key, num=3)
W_v̇_WB = jax.random.uniform(subkey1, shape=(6,), minval=-1)
s̈ = jax.random.uniform(subkey2, shape=(model.dofs(),), minval=-1)

Expand Down Expand Up @@ -178,7 +178,7 @@ def test_ad_crba(

model = jaxsim_models_types

key, subkey = jax.random.split(prng_key, num=2)
_, subkey = jax.random.split(prng_key, num=2)
data, references = get_random_data_and_references(
model=model, velocity_representation=VelRepr.Inertial, key=subkey
)
Expand Down Expand Up @@ -210,7 +210,7 @@ def test_ad_fk(

model = jaxsim_models_types

key, subkey = jax.random.split(prng_key, num=2)
_, subkey = jax.random.split(prng_key, num=2)
data, references = get_random_data_and_references(
model=model, velocity_representation=VelRepr.Inertial, key=subkey
)
Expand Down Expand Up @@ -249,7 +249,7 @@ def test_ad_jacobian(

model = jaxsim_models_types

key, subkey = jax.random.split(prng_key, num=2)
_, subkey = jax.random.split(prng_key, num=2)
data, references = get_random_data_and_references(
model=model, velocity_representation=VelRepr.Inertial, key=subkey
)
Expand Down Expand Up @@ -288,7 +288,7 @@ def test_ad_soft_contacts(

model = jaxsim_models_types

key, subkey1, subkey2, subkey3 = jax.random.split(prng_key, num=4)
_, subkey1, subkey2, subkey3 = jax.random.split(prng_key, num=4)
p = jax.random.uniform(subkey1, shape=(3,), minval=-1)
v = jax.random.uniform(subkey2, shape=(3,), minval=-1)
m = jax.random.uniform(subkey3, shape=(3,), minval=-1)
Expand Down Expand Up @@ -330,7 +330,7 @@ def test_ad_integration(

model = jaxsim_models_types

key, subkey = jax.random.split(prng_key, num=2)
_, subkey = jax.random.split(prng_key, num=2)
data, references = get_random_data_and_references(
model=model, velocity_representation=VelRepr.Inertial, key=subkey
)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_contact.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def test_collidable_point_jacobians(

model = jaxsim_models_types

key, subkey = jax.random.split(prng_key, num=2)
_, subkey = jax.random.split(prng_key, num=2)
data = js.data.random_model_data(
model=model, key=subkey, velocity_representation=velocity_representation
)
Expand Down
11 changes: 4 additions & 7 deletions tests/utils_idyntree.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,12 @@ def build_kindyncomputations_from_jaxsim_model(
removed_joint_positions = removed_joint_positions_default | (
removed_joint_positions
if removed_joint_positions is not None
else {
name: pos
for name, pos in zip(
else dict(
zip(
model.joint_names(),
data.joint_positions(model=model, joint_names=model.joint_names()),
)
}
)
)

# Create the KinDynComputations from the same URDF model.
Expand Down Expand Up @@ -127,9 +126,7 @@ def build(
urdf: pathlib.Path | str,
considered_joints: list[str] = None,
vel_repr: VelRepr = VelRepr.Inertial,
gravity: npt.NDArray = dataclasses.field(
default_factory=lambda: np.array([0, 0, -10.0])
),
gravity: npt.NDArray = np.array([0, 0, -10.0]),
flferretti marked this conversation as resolved.
Show resolved Hide resolved
removed_joint_positions: dict[str, npt.NDArray | float | int] | None = None,
) -> KinDynComputations:

Expand Down