Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Validate link, frame, and joint indices in our jit-compiled APIs #182

Merged
merged 5 commits into from
Jun 19, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 54 additions & 17 deletions src/jaxsim/api/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import jaxsim.api as js
import jaxsim.math
import jaxsim.typing as jtp
from jaxsim import exceptions

from .common import VelRepr

Expand All @@ -17,22 +18,32 @@
# =======================


@jax.jit
def idx_of_parent_link(
model: js.model.JaxSimModel, *, frame_idx: jtp.IntLike
model: js.model.JaxSimModel, *, frame_index: jtp.IntLike
) -> jtp.Int:
"""
Get the index of the link to which the frame is rigidly attached.

Args:
model: The model to consider.
frame_idx: The index of the frame.
frame_index: The index of the frame.

Returns:
The index of the frame's parent link.
"""

n_l = model.number_of_links()
n_f = len(model.frame_names())

exceptions.raise_value_error_if(
condition=jnp.array([frame_index < n_l, frame_index >= n_l + n_f]).any(),
msg="Invalid frame index '{idx}'",
idx=frame_index,
)

return model.kin_dyn_parameters.frame_parameters.body[
frame_idx - model.number_of_links()
frame_index - model.number_of_links()
]


Expand All @@ -49,19 +60,18 @@ def name_to_idx(model: js.model.JaxSimModel, *, frame_name: str) -> jtp.Int:
The index of the frame.
"""

if frame_name in model.kin_dyn_parameters.frame_parameters.name:
return (
jnp.array(
np.argwhere(
np.array(model.kin_dyn_parameters.frame_parameters.name)
== frame_name
)
)
.squeeze()
.astype(int)
) + model.number_of_links()
if frame_name not in model.kin_dyn_parameters.frame_parameters.name:
raise ValueError(f"Frame '{frame_name}' not found in the model.")

return jnp.array(-1).astype(int)
return (
jnp.array(
np.argwhere(
np.array(model.kin_dyn_parameters.frame_parameters.name) == frame_name
)
)
.astype(int)
.squeeze()
) + model.number_of_links()
diegoferigo marked this conversation as resolved.
Show resolved Hide resolved


def idx_to_name(model: js.model.JaxSimModel, *, frame_index: jtp.IntLike) -> str:
Expand All @@ -76,6 +86,15 @@ def idx_to_name(model: js.model.JaxSimModel, *, frame_index: jtp.IntLike) -> str
The name of the frame.
"""

n_l = model.number_of_links()
n_f = len(model.frame_names())

exceptions.raise_value_error_if(
condition=jnp.array([frame_index < n_l, frame_index >= n_l + n_f]).any(),
msg="Invalid frame index '{idx}'",
idx=frame_index,
)

return model.kin_dyn_parameters.frame_parameters.name[
frame_index - model.number_of_links()
]
Expand Down Expand Up @@ -142,8 +161,17 @@ def transform(
The 4x4 matrix representing the transform.
"""

n_l = model.number_of_links()
n_f = len(model.frame_names())

exceptions.raise_value_error_if(
condition=jnp.array([frame_index < n_l, frame_index >= n_l + n_f]).any(),
msg="Invalid frame index '{idx}'",
idx=frame_index,
)

# Compute the necessary transforms.
L = idx_of_parent_link(model=model, frame_idx=frame_index)
L = idx_of_parent_link(model=model, frame_index=frame_index)
W_H_L = js.link.transform(model=model, data=data, link_index=L)

# Get the static frame pose wrt the parent link.
Expand Down Expand Up @@ -181,12 +209,21 @@ def jacobian(
velocity representation.
"""

n_l = model.number_of_links()
n_f = len(model.frame_names())

exceptions.raise_value_error_if(
condition=jnp.array([frame_index < n_l, frame_index >= n_l + n_f]).any(),
msg="Invalid frame index '{idx}'",
idx=frame_index,
)

output_vel_repr = (
output_vel_repr if output_vel_repr is not None else data.velocity_representation
)

# Get the index of the parent link.
L = idx_of_parent_link(model=model, frame_idx=frame_index)
L = idx_of_parent_link(model=model, frame_index=frame_index)

# Compute the Jacobian of the parent link using body-fixed output representation.
L_J_WL = js.link.jacobian(
Expand Down
35 changes: 28 additions & 7 deletions src/jaxsim/api/joint.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import jaxsim.api as js
import jaxsim.typing as jtp
from jaxsim import exceptions

# =======================
# Index-related functions
Expand All @@ -25,14 +26,18 @@ def name_to_idx(model: js.model.JaxSimModel, *, joint_name: str) -> jtp.Int:
The index of the joint.
"""

if joint_name in model.kin_dyn_parameters.joint_model.joint_names:
# Note: the index of the joint for RBDAs starts from 1, but
# the index for accessing the right element starts from 0.
# Therefore, there is a -1.
return jnp.array(
if joint_name not in model.joint_names():
raise ValueError(f"Joint '{joint_name}' not found in the model.")

# Note: the index of the joint for RBDAs starts from 1, but the index for
# accessing the right element starts from 0. Therefore, there is a -1.
return (
jnp.array(
model.kin_dyn_parameters.joint_model.joint_names.index(joint_name) - 1
).squeeze()
return jnp.array(-1).astype(int)
)
.astype(int)
.squeeze()
)


def idx_to_name(model: js.model.JaxSimModel, *, joint_index: jtp.IntLike) -> str:
Expand All @@ -47,6 +52,14 @@ def idx_to_name(model: js.model.JaxSimModel, *, joint_index: jtp.IntLike) -> str
The name of the joint.
"""

exceptions.raise_value_error_if(
condition=jnp.array(
[joint_index < 0, joint_index >= model.number_of_joints()]
).any(),
msg="Invalid joint index '{idx}'",
idx=joint_index,
)

return model.kin_dyn_parameters.joint_model.joint_names[joint_index + 1]


Expand Down Expand Up @@ -112,6 +125,14 @@ def position_limit(
if model.number_of_joints() <= 1:
return jnp.empty(0).astype(float), jnp.empty(0).astype(float)

exceptions.raise_value_error_if(
condition=jnp.array(
[joint_index < 0, joint_index >= model.number_of_joints()]
).any(),
msg="Invalid joint index '{idx}'",
idx=joint_index,
)

s_min = model.kin_dyn_parameters.joint_parameters.position_limits_min[joint_index]
s_max = model.kin_dyn_parameters.joint_parameters.position_limits_max[joint_index]

Expand Down
82 changes: 74 additions & 8 deletions src/jaxsim/api/link.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import jaxsim.api as js
import jaxsim.rbda
import jaxsim.typing as jtp
from jaxsim import exceptions

from .common import VelRepr

Expand All @@ -31,15 +32,16 @@ def name_to_idx(model: js.model.JaxSimModel, *, link_name: str) -> jtp.Int:
The index of the link.
"""

if link_name in model.kin_dyn_parameters.link_names:
return (
jnp.array(
np.argwhere(np.array(model.kin_dyn_parameters.link_names) == link_name)
)
.squeeze()
.astype(int)
if link_name not in model.link_names():
raise ValueError(f"Link '{link_name}' not found in the model.")

return (
jnp.array(
np.argwhere(np.array(model.kin_dyn_parameters.link_names) == link_name)
)
return jnp.array(-1).astype(int)
.astype(int)
.squeeze()
)
diegoferigo marked this conversation as resolved.
Show resolved Hide resolved


def idx_to_name(model: js.model.JaxSimModel, *, link_index: jtp.IntLike) -> str:
Expand All @@ -54,6 +56,14 @@ def idx_to_name(model: js.model.JaxSimModel, *, link_index: jtp.IntLike) -> str:
The name of the link.
"""

exceptions.raise_value_error_if(
condition=jnp.array(
[link_index < 0, link_index >= model.number_of_links()]
).any(),
msg="Invalid link index '{idx}'",
idx=link_index,
)

return model.kin_dyn_parameters.link_names[link_index]


Expand Down Expand Up @@ -112,6 +122,14 @@ def mass(model: js.model.JaxSimModel, *, link_index: jtp.IntLike) -> jtp.Float:
The mass of the link.
"""

exceptions.raise_value_error_if(
condition=jnp.array(
[link_index < 0, link_index >= model.number_of_links()]
).any(),
msg="Invalid link index '{idx}'",
idx=link_index,
)

return model.kin_dyn_parameters.link_parameters.mass[link_index].astype(float)


Expand All @@ -131,6 +149,14 @@ def spatial_inertia(
the link frame (body-fixed representation).
"""

exceptions.raise_value_error_if(
condition=jnp.array(
[link_index < 0, link_index >= model.number_of_links()]
).any(),
msg="Invalid link index '{idx}'",
idx=link_index,
)

link_parameters = jax.tree_util.tree_map(
lambda l: l[link_index], model.kin_dyn_parameters.link_parameters
)
Expand All @@ -157,6 +183,14 @@ def transform(
The 4x4 matrix representing the transform.
"""

exceptions.raise_value_error_if(
condition=jnp.array(
[link_index < 0, link_index >= model.number_of_links()]
).any(),
msg="Invalid link index '{idx}'",
idx=link_index,
)

return js.model.forward_kinematics(model=model, data=data)[link_index]


Expand Down Expand Up @@ -230,6 +264,14 @@ def jacobian(
velocity representation.
"""

exceptions.raise_value_error_if(
condition=jnp.array(
[link_index < 0, link_index >= model.number_of_links()]
).any(),
msg="Invalid link index '{idx}'",
idx=link_index,
)

output_vel_repr = (
output_vel_repr if output_vel_repr is not None else data.velocity_representation
)
Expand Down Expand Up @@ -318,6 +360,14 @@ def velocity(
The 6D velocity of the link in the specified velocity representation.
"""

exceptions.raise_value_error_if(
condition=jnp.array(
[link_index < 0, link_index >= model.number_of_links()]
).any(),
msg="Invalid link index '{idx}'",
idx=link_index,
)

output_vel_repr = (
output_vel_repr if output_vel_repr is not None else data.velocity_representation
)
Expand Down Expand Up @@ -364,6 +414,14 @@ def jacobian_derivative(
velocity representation.
"""

exceptions.raise_value_error_if(
condition=jnp.array(
[link_index < 0, link_index >= model.number_of_links()]
).any(),
msg="Invalid link index '{idx}'",
idx=link_index,
)

output_vel_repr = (
output_vel_repr if output_vel_repr is not None else data.velocity_representation
)
Expand Down Expand Up @@ -538,6 +596,14 @@ def bias_acceleration(
The 6D bias acceleration of the link.
"""

exceptions.raise_value_error_if(
condition=jnp.array(
[link_index < 0, link_index >= model.number_of_links()]
).any(),
msg="Invalid link index '{idx}'",
idx=link_index,
)

# Compute the bias acceleration of all links in the active representation.
O_v̇_WL = js.model.link_bias_accelerations(model=model, data=data)[link_index]
return O_v̇_WL
Loading