Skip to content

Commit

Permalink
Address review
Browse files Browse the repository at this point in the history
- Use explicit joint type names
- Use `jnp.newaxis` instead of `None`
- Use dataclass for `JointType` instead of metaclass

Co-authored-by: Diego Ferigo <[email protected]>
  • Loading branch information
flferretti and diegoferigo committed May 15, 2024
1 parent 486def1 commit 4f142bc
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 36 deletions.
4 changes: 2 additions & 2 deletions src/jaxsim/api/kin_dyn_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,8 +394,8 @@ def joint_transforms_and_motion_subspaces(
# Extract the transforms and motion subspaces of the joints.
# We stack the base transform W_H_B at index 0, and a dummy motion subspace
# for either the fixed or free-floating joint connecting the world to the base.
pre_H_suc = jnp.vstack([W_H_B[None, ...], pre_H_suc_J])
S = jnp.vstack([jnp.zeros((6, 1))[None, ...], S_J])
pre_H_suc = jnp.vstack([W_H_B[jnp.newaxis, ...], pre_H_suc_J])
S = jnp.vstack([jnp.zeros((6, 1))[jnp.newaxis, ...], S_J])

# Extract the successor-to-child fixed transforms.
# Note that here we include also the index 0 since suc_H_child[0] stores the
Expand Down
8 changes: 4 additions & 4 deletions src/jaxsim/math/joint_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def build(description: ModelDescription) -> JointModel:
# Static attributes
joint_dofs=tuple([base_dofs] + [int(1) for _ in ordered_joints]),
joint_names=tuple(["world_to_base"] + [j.name for j in ordered_joints]),
joint_types=tuple([JointType.F] + [j.jtype for j in ordered_joints]),
joint_types=tuple([JointType.Fixed] + [j.jtype for j in ordered_joints]),
joint_axis=tuple([j.axis for j in ordered_joints]),
)

Expand Down Expand Up @@ -265,9 +265,9 @@ def compute_P():
pre_H_suc, S = jax.lax.switch(
index=joint_type,
branches=(
compute_F, # JointType.F
compute_R, # JointType.R
compute_P, # JointType.P
compute_F, # JointType.Fixed
compute_R, # JointType.Revolute
compute_P, # JointType.Prismatic
),
)

Expand Down
32 changes: 6 additions & 26 deletions src/jaxsim/parsers/descriptions/joint.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import dataclasses
from typing import Tuple, Union
from typing import ClassVar, Tuple, Union

import jax_dataclasses
import numpy as np
Expand All @@ -13,31 +13,11 @@
from .link import LinkDescription


class _JointTypeMeta(type):
def __new__(cls, name, bases, dct):
cls_instance = super().__new__(cls, name, bases, dct)

# Assign integer values to the descriptors
cls_instance.F = 0
cls_instance.R = 1
cls_instance.P = 2

return cls_instance


class JointType(metaclass=_JointTypeMeta):
"""
Type of supported joints.
"""

class F:
pass

class R:
pass

class P:
pass
@dataclasses.dataclass(frozen=True)
class JointType:
Fixed: ClassVar[int] = 0
Revolute: ClassVar[int] = 1
Prismatic: ClassVar[int] = 2


@jax_dataclasses.pytree_dataclass
Expand Down
2 changes: 1 addition & 1 deletion src/jaxsim/parsers/rod/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ def build_model_description(
considered_joints=[
j.name
for j in sdf_data.joint_descriptions
if j.jtype is not descriptions.JointType.F
if j.jtype is not descriptions.JointType.Fixed
],
)

Expand Down
6 changes: 3 additions & 3 deletions src/jaxsim/parsers/rod/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def joint_to_joint_type(
joint_type = joint.type

if joint_type == "fixed":
return descriptions.JointType.F
return descriptions.JointType.Fixed

if not (axis.xyz is not None and axis.xyz.xyz is not None):
raise ValueError("Failed to read axis xyz data")
Expand All @@ -86,10 +86,10 @@ def joint_to_joint_type(
axis_xyz = axis_xyz / np.linalg.norm(axis_xyz)

if joint_type in {"revolute", "continuous"}:
return descriptions.JointType.R
return descriptions.JointType.Revolute

if joint_type == "prismatic":
return descriptions.JointType.P
return descriptions.JointType.Prismatic

raise ValueError("Joint not supported", axis_xyz, joint_type)

Expand Down

0 comments on commit 4f142bc

Please sign in to comment.