Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support reducing a model considering non-zero positions of removed joints #137

Merged
merged 17 commits into from
May 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ dependencies:
- isort
- pre-commit
# [testing]
- idyntree
- idyntree >= 12.2.1
- pytest
- pytest-icdiff
- robot_descriptions
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ style =
isort
pre-commit
testing =
idyntree
idyntree >= 12.2.1
pytest >=6.0
pytest-icdiff
robot-descriptions
Expand Down
49 changes: 38 additions & 11 deletions src/jaxsim/api/model.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from __future__ import annotations

import copy
import dataclasses
import functools
import pathlib
from typing import Any
from typing import Any, Sequence

import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -55,7 +56,7 @@ def build_from_model_description(
*,
terrain: jaxsim.terrain.Terrain | None = None,
is_urdf: bool | None = None,
considered_joints: list[str] | None = None,
considered_joints: Sequence[str] | None = None,
) -> JaxSimModel:
"""
Build a Model object from a model description.
Expand Down Expand Up @@ -257,24 +258,50 @@ def link_names(self) -> tuple[str, ...]:
# =====================


def reduce(model: JaxSimModel, considered_joints: tuple[str, ...]) -> JaxSimModel:
def reduce(
model: JaxSimModel,
considered_joints: tuple[str, ...],
locked_joint_positions: dict[str, jtp.Float] | None = None,
) -> JaxSimModel:
"""
Reduce the model by lumping together the links connected by removed joints.

Args:
model: The model to reduce.
considered_joints: The sequence of joints to consider.

Note:
If considered_joints contains joints not existing in the model, the method
will raise an exception. If considered_joints is empty, the method will
return a copy of the input model.
locked_joint_positions:
A dictionary containing the positions of the joints to be considered
in the reduction process. The removed joints in the reduced model
will have their position locked to their value in this dictionary.
If a joint is not part of the dictionary, its position is set to zero.
"""

locked_joint_positions = (
locked_joint_positions if locked_joint_positions is not None else {}
)

# If locked joints are passed, make sure that they are valid.
if not set(locked_joint_positions).issubset(model.joint_names()):
new_joints = set(model.joint_names()) - set(locked_joint_positions)
raise ValueError(f"Passed joints not existing in the model: {new_joints}")

# Copy the model description with a deep copy of the joints.
intermediate_description = dataclasses.replace(
model.description.get(), joints=copy.deepcopy(model.description.get().joints)
)

# Update the initial position of the joints.
# This is necessary to compute the correct pose of the link pairs connected
# to removed joints.
for joint_name in set(model.joint_names()) - set(considered_joints):
j = intermediate_description.joints_dict[joint_name]
with j.mutable_context():
j.initial_position = float(locked_joint_positions.get(joint_name, 0.0))

# Reduce the model description.
# If considered_joints contains joints not existing in the model, the method
# will raise an exception.
reduced_intermediate_description = model.description.obj.reduce(
# If `considered_joints` contains joints not existing in the model,
# the method will raise an exception.
reduced_intermediate_description = intermediate_description.reduce(
considered_joints=list(considered_joints)
)

Expand Down
73 changes: 11 additions & 62 deletions src/jaxsim/math/joint_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
JointType,
ModelDescription,
)
from jaxsim.parsers.kinematic_graph import KinematicGraphTransforms

from .rotation import Rotation

Expand Down Expand Up @@ -87,21 +88,19 @@ def build(description: ModelDescription) -> JointModel:
# w.r.t. the implicit __model__ SDF frame is not the identity).
suc_H_i = suc_H_i.at[0].set(ordered_links[0].pose)

# Create the object to compute forward kinematics.
fk = KinematicGraphTransforms(graph=description)

# Compute the parent-to-predecessor and successor-to-child transforms for
# each joint belonging to the model.
# Note that the joint indices starts from i=1 given our joint model,
# therefore the entries at index 0 are not updated.
for joint in ordered_joints:
λ_H_pre = λ_H_pre.at[joint.index].set(
description.relative_transform(
relative_to=joint.parent.name,
name=joint.name,
)
fk.relative_transform(relative_to=joint.parent.name, name=joint.name)
)
suc_H_i = suc_H_i.at[joint.index].set(
description.relative_transform(
relative_to=joint.name, name=joint.child.name
)
fk.relative_transform(relative_to=joint.name, name=joint.child.name)
)

# Define the DoFs of the base link.
Expand Down Expand Up @@ -243,16 +242,16 @@ def supported_joint_motion(
"""

if isinstance(joint_type, JointType):
code = joint_type
type_enum = joint_type
elif isinstance(joint_type, JointDescriptor):
code = joint_type.code
type_enum = joint_type.joint_type
else:
raise ValueError(joint_type)

# Prepare the joint position
s = jnp.array(joint_position).astype(float)

match code:
match type_enum:

case JointType.R:
joint_type: JointGenericAxis
Expand All @@ -276,58 +275,8 @@ def supported_joint_motion(
S = jnp.vstack(jnp.hstack([joint_type.axis.squeeze(), jnp.zeros(3)]))

case JointType.F:
raise ValueError("Fixed joints shouldn't be here")

case JointType.Rx:

pre_H_suc = jaxlie.SE3.from_rotation(
rotation=jaxlie.SO3.from_x_radians(theta=s)
)

S = jnp.vstack([0, 0, 0, 1.0, 0, 0])

case JointType.Ry:

pre_H_suc = jaxlie.SE3.from_rotation(
rotation=jaxlie.SO3.from_y_radians(theta=s)
)

S = jnp.vstack([0, 0, 0, 0, 1.0, 0])

case JointType.Rz:

pre_H_suc = jaxlie.SE3.from_rotation(
rotation=jaxlie.SO3.from_z_radians(theta=s)
)

S = jnp.vstack([0, 0, 0, 0, 0, 1.0])

case JointType.Px:

pre_H_suc = jaxlie.SE3.from_rotation_and_translation(
rotation=jaxlie.SO3.identity(),
translation=jnp.array([s, 0.0, 0.0]),
)

S = jnp.vstack([1.0, 0, 0, 0, 0, 0])

case JointType.Py:

pre_H_suc = jaxlie.SE3.from_rotation_and_translation(
rotation=jaxlie.SO3.identity(),
translation=jnp.array([0.0, s, 0.0]),
)

S = jnp.vstack([0, 1.0, 0, 0, 0, 0])

case JointType.Pz:

pre_H_suc = jaxlie.SE3.from_rotation_and_translation(
rotation=jaxlie.SO3.identity(),
translation=jnp.array([0.0, 0.0, s]),
)

S = jnp.vstack([0, 0, 1.0, 0, 0, 0])
pre_H_suc = jaxlie.SE3.identity()
S = jnp.zeros(shape=(6, 1))

case _:
raise ValueError(joint_type)
Expand Down
77 changes: 30 additions & 47 deletions src/jaxsim/parsers/descriptions/joint.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import dataclasses
import enum
from typing import Tuple, Union
Expand All @@ -6,79 +8,60 @@
import numpy as np
import numpy.typing as npt

import jaxsim.typing as jtp
from jaxsim.utils import JaxsimDataclass, Mutability

from .link import LinkDescription


@enum.unique
class JointType(enum.IntEnum):
"""
Enumeration of joint types for robot joints.

Args:
F: Fixed joint (no movement).
R: Revolute joint (rotation).
P: Prismatic joint (translation).
Rx: Revolute joint with rotation about the X-axis.
Ry: Revolute joint with rotation about the Y-axis.
Rz: Revolute joint with rotation about the Z-axis.
Px: Prismatic joint with translation along the X-axis.
Py: Prismatic joint with translation along the Y-axis.
Pz: Prismatic joint with translation along the Z-axis.
Type of supported joints.
"""

F = enum.auto() # Fixed
R = enum.auto() # Revolute
P = enum.auto() # Prismatic
@staticmethod
def _generate_next_value_(name, start, count, last_values):
# Start auto Enum value from 0 instead of 1
return count

#: Fixed joint.
F = enum.auto()

# Revolute joints, single axis
Rx = enum.auto()
Ry = enum.auto()
Rz = enum.auto()
#: Revolute joint (1 DoF around axis).
R = enum.auto()

# Prismatic joints, single axis
Px = enum.auto()
Py = enum.auto()
Pz = enum.auto()
#: Prismatic joint (1 DoF along axis).
P = enum.auto()


@dataclasses.dataclass
@jax_dataclasses.pytree_dataclass
class JointDescriptor:
"""
Description of a joint type with a specific code.

Args:
code (JointType): The code representing the joint type.

Base class for joint types requiring to store additional metadata.
"""

code: JointType
#: The joint type.
joint_type: JointType

def __hash__(self) -> int:
return hash(self.__repr__())


@dataclasses.dataclass
@jax_dataclasses.pytree_dataclass
class JointGenericAxis(JointDescriptor):
"""
Description of a joint type with a generic axis.

Attributes:
axis (npt.NDArray): The axis of rotation or translation for the joint.

A joint requiring the specification of a 3D axis.
"""

axis: npt.NDArray
#: The axis of rotation or translation of the joint (must have norm 1).
axis: jtp.Vector

def __post_init__(self):
if np.allclose(self.axis, 0.0):
raise ValueError(self.axis)
def __hash__(self) -> int:
return hash((self.joint_type, tuple(np.array(self.axis).tolist())))

def __eq__(self, other):
return super().__eq__(other) and np.allclose(self.axis, other.axis)
def __eq__(self, other: JointGenericAxis) -> bool:
if not isinstance(other, JointGenericAxis):
return False

def __hash__(self) -> int:
return hash(self.__repr__())
return hash(self) == hash(other)


@jax_dataclasses.pytree_dataclass
Expand Down
20 changes: 14 additions & 6 deletions src/jaxsim/parsers/descriptions/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from jaxsim import logging

from ..kinematic_graph import KinematicGraph, RootPose
from ..kinematic_graph import KinematicGraph, KinematicGraphTransforms, RootPose
from .collision import CollidablePoint, CollisionShape
from .joint import JointDescription
from .link import LinkDescription
Expand Down Expand Up @@ -75,6 +75,9 @@ def build_model_from(
considered_joints=considered_joints
)

# Create the object to compute forward kinematics.
fk = KinematicGraphTransforms(graph=kinematic_graph)

# Store here the final model collisions
final_collisions: List[CollisionShape] = []

Expand Down Expand Up @@ -121,7 +124,7 @@ def build_model_from(
# relative pose
moved_cp = cp.change_link(
new_link=real_parent_link_of_shape,
new_H_old=kinematic_graph.relative_transform(
new_H_old=fk.relative_transform(
relative_to=real_parent_link_of_shape.name,
name=cp.parent_link.name,
),
Expand All @@ -139,7 +142,9 @@ def build_model_from(
root=kinematic_graph.root,
joints=kinematic_graph.joints,
frames=kinematic_graph.frames,
_joints_removed=kinematic_graph._joints_removed,
)

assert kinematic_graph.root.name == base_link_name, kinematic_graph.root.name

return model
Expand All @@ -158,15 +163,12 @@ def reduce(self, considered_joints: List[str]) -> "ModelDescription":
ValueError: If the specified joints are not part of the model.
"""

msg = "The model reduction logic assumes that removed joints have zero angles"
logging.info(msg=msg)

if len(set(considered_joints) - set(self.joint_names())) != 0:
extra_joints = set(considered_joints) - set(self.joint_names())
msg = f"Found joints not part of the model: {extra_joints}"
raise ValueError(msg)

return ModelDescription.build_model_from(
reduced_model_description = ModelDescription.build_model_from(
name=self.name,
links=list(self.links_dict.values()),
joints=self.joints,
Expand All @@ -177,6 +179,12 @@ def reduce(self, considered_joints: List[str]) -> "ModelDescription":
considered_joints=considered_joints,
)

# Include the unconnected/removed joints from the original model.
for joint in self._joints_removed:
reduced_model_description._joints_removed.append(joint)

return reduced_model_description

def update_collision_shape_of_link(self, link_name: str, enabled: bool) -> None:
"""
Enable or disable collision shapes associated with a link.
Expand Down
Loading
Loading