diff --git a/environment.yml b/environment.yml index 9b24a505b..cc6a4023b 100644 --- a/environment.yml +++ b/environment.yml @@ -22,7 +22,7 @@ dependencies: - isort - pre-commit # [testing] - - idyntree + - idyntree >= 12.2.1 - pytest - pytest-icdiff - robot_descriptions diff --git a/setup.cfg b/setup.cfg index 33e19bd86..887110b8a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -73,7 +73,7 @@ style = isort pre-commit testing = - idyntree + idyntree >= 12.2.1 pytest >=6.0 pytest-icdiff robot-descriptions diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index 987a2e9f9..9d14a8fef 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -1,9 +1,10 @@ from __future__ import annotations +import copy import dataclasses import functools import pathlib -from typing import Any +from typing import Any, Sequence import jax import jax.numpy as jnp @@ -55,7 +56,7 @@ def build_from_model_description( *, terrain: jaxsim.terrain.Terrain | None = None, is_urdf: bool | None = None, - considered_joints: list[str] | None = None, + considered_joints: Sequence[str] | None = None, ) -> JaxSimModel: """ Build a Model object from a model description. @@ -257,24 +258,50 @@ def link_names(self) -> tuple[str, ...]: # ===================== -def reduce(model: JaxSimModel, considered_joints: tuple[str, ...]) -> JaxSimModel: +def reduce( + model: JaxSimModel, + considered_joints: tuple[str, ...], + locked_joint_positions: dict[str, jtp.Float] | None = None, +) -> JaxSimModel: """ Reduce the model by lumping together the links connected by removed joints. Args: model: The model to reduce. considered_joints: The sequence of joints to consider. - - Note: - If considered_joints contains joints not existing in the model, the method - will raise an exception. If considered_joints is empty, the method will - return a copy of the input model. + locked_joint_positions: + A dictionary containing the positions of the joints to be considered + in the reduction process. The removed joints in the reduced model + will have their position locked to their value in this dictionary. + If a joint is not part of the dictionary, its position is set to zero. """ + locked_joint_positions = ( + locked_joint_positions if locked_joint_positions is not None else {} + ) + + # If locked joints are passed, make sure that they are valid. + if not set(locked_joint_positions).issubset(model.joint_names()): + new_joints = set(model.joint_names()) - set(locked_joint_positions) + raise ValueError(f"Passed joints not existing in the model: {new_joints}") + + # Copy the model description with a deep copy of the joints. + intermediate_description = dataclasses.replace( + model.description.get(), joints=copy.deepcopy(model.description.get().joints) + ) + + # Update the initial position of the joints. + # This is necessary to compute the correct pose of the link pairs connected + # to removed joints. + for joint_name in set(model.joint_names()) - set(considered_joints): + j = intermediate_description.joints_dict[joint_name] + with j.mutable_context(): + j.initial_position = float(locked_joint_positions.get(joint_name, 0.0)) + # Reduce the model description. - # If considered_joints contains joints not existing in the model, the method - # will raise an exception. - reduced_intermediate_description = model.description.obj.reduce( + # If `considered_joints` contains joints not existing in the model, + # the method will raise an exception. + reduced_intermediate_description = intermediate_description.reduce( considered_joints=list(considered_joints) ) diff --git a/src/jaxsim/math/joint_model.py b/src/jaxsim/math/joint_model.py index d5cf436fb..6e52952a2 100644 --- a/src/jaxsim/math/joint_model.py +++ b/src/jaxsim/math/joint_model.py @@ -15,6 +15,7 @@ JointType, ModelDescription, ) +from jaxsim.parsers.kinematic_graph import KinematicGraphTransforms from .rotation import Rotation @@ -87,21 +88,19 @@ def build(description: ModelDescription) -> JointModel: # w.r.t. the implicit __model__ SDF frame is not the identity). suc_H_i = suc_H_i.at[0].set(ordered_links[0].pose) + # Create the object to compute forward kinematics. + fk = KinematicGraphTransforms(graph=description) + # Compute the parent-to-predecessor and successor-to-child transforms for # each joint belonging to the model. # Note that the joint indices starts from i=1 given our joint model, # therefore the entries at index 0 are not updated. for joint in ordered_joints: λ_H_pre = λ_H_pre.at[joint.index].set( - description.relative_transform( - relative_to=joint.parent.name, - name=joint.name, - ) + fk.relative_transform(relative_to=joint.parent.name, name=joint.name) ) suc_H_i = suc_H_i.at[joint.index].set( - description.relative_transform( - relative_to=joint.name, name=joint.child.name - ) + fk.relative_transform(relative_to=joint.name, name=joint.child.name) ) # Define the DoFs of the base link. @@ -243,16 +242,16 @@ def supported_joint_motion( """ if isinstance(joint_type, JointType): - code = joint_type + type_enum = joint_type elif isinstance(joint_type, JointDescriptor): - code = joint_type.code + type_enum = joint_type.joint_type else: raise ValueError(joint_type) # Prepare the joint position s = jnp.array(joint_position).astype(float) - match code: + match type_enum: case JointType.R: joint_type: JointGenericAxis @@ -276,58 +275,8 @@ def supported_joint_motion( S = jnp.vstack(jnp.hstack([joint_type.axis.squeeze(), jnp.zeros(3)])) case JointType.F: - raise ValueError("Fixed joints shouldn't be here") - - case JointType.Rx: - - pre_H_suc = jaxlie.SE3.from_rotation( - rotation=jaxlie.SO3.from_x_radians(theta=s) - ) - - S = jnp.vstack([0, 0, 0, 1.0, 0, 0]) - - case JointType.Ry: - - pre_H_suc = jaxlie.SE3.from_rotation( - rotation=jaxlie.SO3.from_y_radians(theta=s) - ) - - S = jnp.vstack([0, 0, 0, 0, 1.0, 0]) - - case JointType.Rz: - - pre_H_suc = jaxlie.SE3.from_rotation( - rotation=jaxlie.SO3.from_z_radians(theta=s) - ) - - S = jnp.vstack([0, 0, 0, 0, 0, 1.0]) - - case JointType.Px: - - pre_H_suc = jaxlie.SE3.from_rotation_and_translation( - rotation=jaxlie.SO3.identity(), - translation=jnp.array([s, 0.0, 0.0]), - ) - - S = jnp.vstack([1.0, 0, 0, 0, 0, 0]) - - case JointType.Py: - - pre_H_suc = jaxlie.SE3.from_rotation_and_translation( - rotation=jaxlie.SO3.identity(), - translation=jnp.array([0.0, s, 0.0]), - ) - - S = jnp.vstack([0, 1.0, 0, 0, 0, 0]) - - case JointType.Pz: - - pre_H_suc = jaxlie.SE3.from_rotation_and_translation( - rotation=jaxlie.SO3.identity(), - translation=jnp.array([0.0, 0.0, s]), - ) - - S = jnp.vstack([0, 0, 1.0, 0, 0, 0]) + pre_H_suc = jaxlie.SE3.identity() + S = jnp.zeros(shape=(6, 1)) case _: raise ValueError(joint_type) diff --git a/src/jaxsim/parsers/descriptions/joint.py b/src/jaxsim/parsers/descriptions/joint.py index 97db9c86c..26af59000 100644 --- a/src/jaxsim/parsers/descriptions/joint.py +++ b/src/jaxsim/parsers/descriptions/joint.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import dataclasses import enum from typing import Tuple, Union @@ -6,79 +8,60 @@ import numpy as np import numpy.typing as npt +import jaxsim.typing as jtp from jaxsim.utils import JaxsimDataclass, Mutability from .link import LinkDescription +@enum.unique class JointType(enum.IntEnum): """ - Enumeration of joint types for robot joints. - - Args: - F: Fixed joint (no movement). - R: Revolute joint (rotation). - P: Prismatic joint (translation). - Rx: Revolute joint with rotation about the X-axis. - Ry: Revolute joint with rotation about the Y-axis. - Rz: Revolute joint with rotation about the Z-axis. - Px: Prismatic joint with translation along the X-axis. - Py: Prismatic joint with translation along the Y-axis. - Pz: Prismatic joint with translation along the Z-axis. + Type of supported joints. """ - F = enum.auto() # Fixed - R = enum.auto() # Revolute - P = enum.auto() # Prismatic + @staticmethod + def _generate_next_value_(name, start, count, last_values): + # Start auto Enum value from 0 instead of 1 + return count + + #: Fixed joint. + F = enum.auto() - # Revolute joints, single axis - Rx = enum.auto() - Ry = enum.auto() - Rz = enum.auto() + #: Revolute joint (1 DoF around axis). + R = enum.auto() - # Prismatic joints, single axis - Px = enum.auto() - Py = enum.auto() - Pz = enum.auto() + #: Prismatic joint (1 DoF along axis). + P = enum.auto() -@dataclasses.dataclass +@jax_dataclasses.pytree_dataclass class JointDescriptor: """ - Description of a joint type with a specific code. - - Args: - code (JointType): The code representing the joint type. - + Base class for joint types requiring to store additional metadata. """ - code: JointType + #: The joint type. + joint_type: JointType - def __hash__(self) -> int: - return hash(self.__repr__()) - -@dataclasses.dataclass +@jax_dataclasses.pytree_dataclass class JointGenericAxis(JointDescriptor): """ - Description of a joint type with a generic axis. - - Attributes: - axis (npt.NDArray): The axis of rotation or translation for the joint. - + A joint requiring the specification of a 3D axis. """ - axis: npt.NDArray + #: The axis of rotation or translation of the joint (must have norm 1). + axis: jtp.Vector - def __post_init__(self): - if np.allclose(self.axis, 0.0): - raise ValueError(self.axis) + def __hash__(self) -> int: + return hash((self.joint_type, tuple(np.array(self.axis).tolist()))) - def __eq__(self, other): - return super().__eq__(other) and np.allclose(self.axis, other.axis) + def __eq__(self, other: JointGenericAxis) -> bool: + if not isinstance(other, JointGenericAxis): + return False - def __hash__(self) -> int: - return hash(self.__repr__()) + return hash(self) == hash(other) @jax_dataclasses.pytree_dataclass diff --git a/src/jaxsim/parsers/descriptions/model.py b/src/jaxsim/parsers/descriptions/model.py index 9dd9e99f1..51be58c3d 100644 --- a/src/jaxsim/parsers/descriptions/model.py +++ b/src/jaxsim/parsers/descriptions/model.py @@ -4,7 +4,7 @@ from jaxsim import logging -from ..kinematic_graph import KinematicGraph, RootPose +from ..kinematic_graph import KinematicGraph, KinematicGraphTransforms, RootPose from .collision import CollidablePoint, CollisionShape from .joint import JointDescription from .link import LinkDescription @@ -75,6 +75,9 @@ def build_model_from( considered_joints=considered_joints ) + # Create the object to compute forward kinematics. + fk = KinematicGraphTransforms(graph=kinematic_graph) + # Store here the final model collisions final_collisions: List[CollisionShape] = [] @@ -121,7 +124,7 @@ def build_model_from( # relative pose moved_cp = cp.change_link( new_link=real_parent_link_of_shape, - new_H_old=kinematic_graph.relative_transform( + new_H_old=fk.relative_transform( relative_to=real_parent_link_of_shape.name, name=cp.parent_link.name, ), @@ -139,7 +142,9 @@ def build_model_from( root=kinematic_graph.root, joints=kinematic_graph.joints, frames=kinematic_graph.frames, + _joints_removed=kinematic_graph._joints_removed, ) + assert kinematic_graph.root.name == base_link_name, kinematic_graph.root.name return model @@ -158,15 +163,12 @@ def reduce(self, considered_joints: List[str]) -> "ModelDescription": ValueError: If the specified joints are not part of the model. """ - msg = "The model reduction logic assumes that removed joints have zero angles" - logging.info(msg=msg) - if len(set(considered_joints) - set(self.joint_names())) != 0: extra_joints = set(considered_joints) - set(self.joint_names()) msg = f"Found joints not part of the model: {extra_joints}" raise ValueError(msg) - return ModelDescription.build_model_from( + reduced_model_description = ModelDescription.build_model_from( name=self.name, links=list(self.links_dict.values()), joints=self.joints, @@ -177,6 +179,12 @@ def reduce(self, considered_joints: List[str]) -> "ModelDescription": considered_joints=considered_joints, ) + # Include the unconnected/removed joints from the original model. + for joint in self._joints_removed: + reduced_model_description._joints_removed.append(joint) + + return reduced_model_description + def update_collision_shape_of_link(self, link_name: str, enabled: bool) -> None: """ Enable or disable collision shapes associated with a link. diff --git a/src/jaxsim/parsers/kinematic_graph.py b/src/jaxsim/parsers/kinematic_graph.py index 5841e35d8..040c6d332 100644 --- a/src/jaxsim/parsers/kinematic_graph.py +++ b/src/jaxsim/parsers/kinematic_graph.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import copy import dataclasses import functools @@ -9,6 +11,7 @@ List, NamedTuple, Optional, + Sequence, Tuple, Union, ) @@ -41,7 +44,7 @@ def __eq__(self, other): @dataclasses.dataclass(frozen=True) -class KinematicGraph: +class KinematicGraph(Sequence[descriptions.LinkDescription]): """ Represents a kinematic graph of links and joints. @@ -76,6 +79,12 @@ class KinematicGraph: repr=False, compare=False, default_factory=dict ) + # Private attribute storing the unconnected joints from the parsed model and + # the joints removed after model reduction. + _joints_removed: list[descriptions.JointDescription] = dataclasses.field( + default_factory=list, repr=False, compare=False + ) + @functools.cached_property def links_dict(self) -> Dict[str, descriptions.LinkDescription]: return {l.name: l for l in iter(self)} @@ -153,15 +162,24 @@ def build_from( # Couple links and joints and create the graph of links. # Note that the pose of the frames is not updated; it's the caller's # responsibility to update their pose if they want to use them. - graph_root_node, graph_joints, graph_frames = KinematicGraph.create_graph( - links=links, joints=joints, root_link_name=root_link_name + graph_root_node, graph_joints, graph_frames, unconnected_joints = ( + KinematicGraph.create_graph( + links=links, joints=joints, root_link_name=root_link_name + ) ) for frame in graph_frames: logging.warning(msg=f"Ignoring unconnected link / frame: '{frame.name}'") + for joint in unconnected_joints: + logging.warning(msg=f"Ignoring unconnected joint: '{joint.name}'") + return KinematicGraph( - root=graph_root_node, joints=graph_joints, frames=[], root_pose=root_pose + root=graph_root_node, + joints=graph_joints, + frames=[], + root_pose=root_pose, + _joints_removed=unconnected_joints, ) @staticmethod @@ -173,9 +191,10 @@ def create_graph( descriptions.LinkDescription, List[descriptions.JointDescription], List[descriptions.LinkDescription], + list[descriptions.JointDescription], ]: """ - Create a kinematic graph from lists of links and joints. + Create a kinematic graph from the lists of parsed links and joints. Args: links (List[descriptions.LinkDescription]): A list of link descriptions. @@ -183,8 +202,9 @@ def create_graph( root_link_name (str): The name of the root link. Returns: - Tuple[descriptions.LinkDescription, List[descriptions.JointDescription], List[descriptions.LinkDescription]]: - A tuple containing the root link, list of joints, and list of frames in the graph. + A tuple containing the root node with the full kinematic graph as child nodes, + the list of joints associated to graph nodes, the list of frames rigidly + attached to graph nodes, and the list of joints not part of the graph. """ # Create a dict that maps link name to the link, for easy retrieval @@ -246,18 +266,25 @@ def create_graph( links_dict[root_link_name].mutable(mutable=False), list(set(joints) - set(removed_joints)), frames, + list(set(removed_joints)), ) - def reduce(self, considered_joints: List[str]) -> "KinematicGraph": + def reduce(self, considered_joints: List[str]) -> KinematicGraph: """ - Reduce the kinematic graph by removing specified joints and lumping the mass and inertia of removed links into their parent links. + Reduce the kinematic graph by removing unspecified joints. + + When a joint is removed, the mass and inertia of its child link are lumped + with those of its parent link, obtaining a new link that combines the two. + The description of the removed joint specifies the default angle (usually 0) + that is considered when the joint is removed. Args: - considered_joints (List[str]): A list of joint names to consider. + considered_joints: A list of joint names to consider. Returns: - KinematicGraph: The reduced kinematic graph. + The reduced kinematic graph. """ + # The current object represents the complete kinematic graph full_graph = self @@ -281,6 +308,9 @@ def reduce(self, considered_joints: List[str]) -> "KinematicGraph": links_dict = copy.deepcopy(full_graph.links_dict) joints_dict = copy.deepcopy(full_graph.joints_dict) + # Create the object to compute forward kinematics. + fk = KinematicGraphTransforms(graph=full_graph) + # The following steps are implemented below in order to create the reduced graph: # # 1. Lump the mass of the removed links into their parent @@ -329,7 +359,7 @@ def reduce(self, considered_joints: List[str]) -> "KinematicGraph": # Lump the link lumped_link = parent_of_link_to_remove.lump_with( link=link_to_remove, - lumped_H_removed=full_graph.relative_transform( + lumped_H_removed=fk.relative_transform( relative_to=parent_of_link_to_remove.name, name=link_to_remove.name ), ) @@ -370,7 +400,7 @@ def reduce(self, considered_joints: List[str]) -> "KinematicGraph": # Update the pose. Note that after the lumping process, the dict entry # links_dict[joint.parent.name] contains the final lumped link with joint.mutable_context(mutability=Mutability.MUTABLE): - joint.pose = full_graph.relative_transform( + joint.pose = fk.relative_transform( relative_to=links_dict[joint.parent.name].name, name=joint.name ) with joint.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION): @@ -396,18 +426,25 @@ def reduce(self, considered_joints: List[str]) -> "KinematicGraph": # Create the reduced graph data. We pass the full list of links so that those # that are not part of the graph will be returned as frames. - reduced_root_node, reduced_joints, reduced_frames = KinematicGraph.create_graph( - links=list(full_graph_links_dict.values()), - joints=[joints_dict[joint_name] for joint_name in considered_joints], - root_link_name=full_graph.root.name, + reduced_root_node, reduced_joints, reduced_frames, unconnected_joints = ( + KinematicGraph.create_graph( + links=list(full_graph_links_dict.values()), + joints=[joints_dict[joint_name] for joint_name in considered_joints], + root_link_name=full_graph.root.name, + ) ) # Create the reduced graph reduced_graph = KinematicGraph( root=reduced_root_node, joints=reduced_joints, - frames=reduced_frames, + frames=self.frames + reduced_frames, root_pose=full_graph.root_pose, + _joints_removed=( + self._joints_removed + + unconnected_joints + + [joints_dict[name] for name in joint_names_to_remove] + ), ) # ================================================================ @@ -424,7 +461,7 @@ def reduce(self, considered_joints: List[str]) -> "KinematicGraph": # Update the connection of the frame frame.parent = new_parent_link - frame.pose = full_graph.relative_transform( + frame.pose = fk.relative_transform( relative_to=new_parent_link.name, name=frame.name ) @@ -462,65 +499,6 @@ def frame_names(self) -> List[str]: """ return list(self.frames_dict.keys()) - def transform(self, name: str) -> npt.NDArray: - """ - Compute the transformation matrix for a given link, joint, or frame. - - Args: - name (str): The name of the link, joint, or frame. - - Returns: - npt.NDArray: The transformation matrix. - """ - if name in self.transform_cache: - return self.transform_cache[name] - - if name in self.joint_names(): - joint = self.joints_dict[name] - - if joint.initial_position != 0.0: - msg = f"Ignoring unsupported initial position of joint '{name}'" - logging.warning(msg=msg) - - transform = self.transform(name=joint.parent.name) @ joint.pose - self.transform_cache[name] = transform - return self.transform_cache[name] - - if name in self.link_names(): - link = self.links_dict[name] - - if link.name == self.root.name: - return link.pose - - parent_joint = self.joints_connection_dict[(link.parent.name, link.name)] - transform = self.transform(name=parent_joint.name) @ link.pose - self.transform_cache[name] = transform - return self.transform_cache[name] - - # It can only be a plain frame - if name not in self.frame_names(): - raise ValueError(name) - - frame = self.frames_dict[name] - transform = self.transform(name=frame.parent.name) @ frame.pose - self.transform_cache[name] = transform - return self.transform_cache[name] - - def relative_transform(self, relative_to: str, name: str) -> npt.NDArray: - """ - Compute the relative transformation matrix between two elements in the kinematic graph. - - Args: - relative_to (str): The name of the reference element. - name (str): The name of the element to compute the relative transformation for. - - Returns: - npt.NDArray: The relative transformation matrix. - """ - return np.linalg.inv(self.transform(name=relative_to)) @ self.transform( - name=name - ) - def print_tree(self) -> None: """ Print the tree structure of the kinematic graph. @@ -574,6 +552,10 @@ def breadth_first_search( yield child + # ================= + # Sequence protocol + # ================= + def __iter__(self) -> Iterable[descriptions.LinkDescription]: yield from KinematicGraph.breadth_first_search(root=self.root) @@ -606,3 +588,188 @@ def __getitem__(self, key: Union[int, str]) -> descriptions.LinkDescription: return list(iter(self))[key] raise TypeError(type(key).__name__) + + def count(self, value: descriptions.LinkDescription) -> int: + return list(iter(self)).count(value) + + def index( + self, value: descriptions.LinkDescription, start: int = 0, stop: int = -1 + ) -> int: + return list(iter(self)).index(value, start, stop) + + +# ==================== +# Other useful classes +# ==================== + + +@dataclasses.dataclass(frozen=True) +class KinematicGraphTransforms: + + graph: KinematicGraph + + _transform_cache: dict[str, npt.NDArray] = dataclasses.field( + default_factory=dict, init=False, repr=False, compare=False + ) + + _initial_joint_positions: dict[str, float] = dataclasses.field( + init=False, repr=False, compare=False + ) + + def __post_init__(self) -> None: + + super().__setattr__( + "_initial_joint_positions", + {joint.name: joint.initial_position for joint in self.graph.joints}, + ) + + @property + def initial_joint_positions(self) -> npt.NDArray: + + return np.atleast_1d( + np.array(list(self._initial_joint_positions.values())) + ).astype(float) + + @initial_joint_positions.setter + def initial_joint_positions( + self, + positions: npt.NDArray | Sequence, + joint_names: Sequence[str] | None = None, + ) -> None: + + joint_names = ( + joint_names + if joint_names is not None + else list(self._initial_joint_positions.keys()) + ) + + s = np.atleast_1d(np.array(positions).squeeze()) + + if s.size != len(joint_names): + raise ValueError(s.size, len(joint_names)) + + for joint_name in joint_names: + if joint_name not in self._initial_joint_positions: + raise ValueError(joint_name) + + # Clear transform cache. + self._transform_cache.clear() + + # Update initial joint positions. + for joint_name, position in zip(joint_names, s): + self._initial_joint_positions[joint_name] = position + + def transform(self, name: str) -> npt.NDArray: + """ + Compute the SE(3) transform of elements belonging to the kinematic graph. + + Args: + name: The name of a link, a joint, or a frame. + + Returns: + The 4x4 transform matrix of the element w.r.t. the model frame. + """ + + # If the transform was already computed, return it. + if name in self._transform_cache: + return self._transform_cache[name] + + # If the name is a joint, compute M_H_J transform. + if name in self.graph.joint_names(): + + # Get the joint. + joint = self.graph.joints_dict[name] + + # Get the transform of the parent link. + M_H_L = self.transform(name=joint.parent.name) + + # Rename the pose of the predecessor joint frame w.r.t. its parent link. + L_H_pre = joint.pose + + # Compute the joint transform from the predecessor to the successor frame. + pre_H_J = self.pre_H_suc( + joint_type=joint.jtype, + joint_position=self._initial_joint_positions[joint.name], + ) + + # Compute the M_H_J transform. + self._transform_cache[name] = M_H_L @ L_H_pre @ pre_H_J + return self._transform_cache[name] + + # If the name is a link, compute M_H_L transform. + if name in self.graph.link_names(): + + # Get the link. + link = self.graph.links_dict[name] + + # Handle the pose between the __model__ frame and the root link. + if link.name == self.graph.root.name: + M_H_B = link.pose + return M_H_B + + # Get the joint between the link and its parent. + parent_joint = self.graph.joints_connection_dict[ + (link.parent.name, link.name) + ] + + # Get the transform of the parent joint. + M_H_J = self.transform(name=parent_joint.name) + + # Rename the pose of the link w.r.t. its parent joint. + J_H_L = link.pose + + # Compute the M_H_L transform. + self._transform_cache[name] = M_H_J @ J_H_L + return self._transform_cache[name] + + # It can only be a plain frame. + if name not in self.graph.frame_names(): + raise ValueError(name) + + # Get the frame. + frame = self.graph.frames_dict[name] + + # Get the transform of the parent link. + M_H_L = self.transform(name=frame.parent.name) + + # Rename the pose of the frame w.r.t. its parent link. + L_H_F = frame.pose + + # Compute the M_H_F transform. + self._transform_cache[name] = M_H_L @ L_H_F + return self._transform_cache[name] + + def relative_transform(self, relative_to: str, name: str) -> npt.NDArray: + """ + Compute the SE(3) relative transform of elements belonging to the kinematic graph. + + Args: + relative_to: The name of the reference element. + name: The name of a link, a joint, or a frame. + + Returns: + The 4x4 transform matrix of the element w.r.t. the desired frame. + """ + + import jaxsim.math + + M_H_target = self.transform(name=name) + M_H_R = self.transform(name=relative_to) + + # Compute the relative transform R_H_target, where R is the reference frame, + # and i the frame of the desired link|joint|frame. + return np.array(jaxsim.math.Transform.inverse(M_H_R)) @ M_H_target + + @staticmethod + def pre_H_suc( + joint_type: descriptions.JointType | descriptions.JointDescriptor, + joint_position: float | None = None, + ) -> npt.NDArray: + + import jaxsim.math + + return np.array( + jaxsim.math.supported_joint_motion( + joint_type=joint_type, joint_position=joint_position + )[0] + ) diff --git a/src/jaxsim/parsers/rod/parser.py b/src/jaxsim/parsers/rod/parser.py index eef416734..ed49e7702 100644 --- a/src/jaxsim/parsers/rod/parser.py +++ b/src/jaxsim/parsers/rod/parser.py @@ -134,7 +134,7 @@ def extract_model_data( name=j.name, parent=world_link, child=links_dict[j.child], - jtype=utils.axis_to_jtype(axis=j.axis, type=j.type), + jtype=utils.joint_to_joint_type(joint=j), axis=( np.array(j.axis.xyz.xyz) if j.axis is not None @@ -201,7 +201,7 @@ def extract_model_data( name=j.name, parent=links_dict[j.parent], child=links_dict[j.child], - jtype=utils.axis_to_jtype(axis=j.axis, type=j.type), + jtype=utils.joint_to_joint_type(joint=j), axis=( np.array(j.axis.xyz.xyz) if j.axis is not None diff --git a/src/jaxsim/parsers/rod/utils.py b/src/jaxsim/parsers/rod/utils.py index b281abc9f..afdc880d5 100644 --- a/src/jaxsim/parsers/rod/utils.py +++ b/src/jaxsim/parsers/rod/utils.py @@ -1,5 +1,4 @@ import os -from typing import Union import jaxlie import numpy as np @@ -60,57 +59,43 @@ def from_sdf_inertial(inertial: rod.Inertial) -> jtp.Matrix: return M_L.astype(dtype=float) -def axis_to_jtype( - axis: rod.Axis, type: str -) -> Union[descriptions.JointType, descriptions.JointDescriptor]: +def joint_to_joint_type( + joint: rod.Joint, +) -> descriptions.JointType | descriptions.JointDescriptor: """ - Convert an SDF axis to a joint type. + Extract the joint type from an SDF joint. Args: - axis: The SDF axis. - type: The SDF joint type. + joint: The parsed SDF joint. Returns: The corresponding joint type description. """ - if type == "fixed": + axis = joint.axis + joint_type = joint.type + + if joint_type == "fixed": return descriptions.JointType.F if not (axis.xyz is not None and axis.xyz.xyz is not None): raise ValueError("Failed to read axis xyz data") - axis_xyz = np.array(axis.xyz.xyz) - - if np.allclose(axis_xyz, [1, 0, 0]) and type in {"revolute", "continuous"}: - return descriptions.JointType.Rx - - if np.allclose(axis_xyz, [0, 1, 0]) and type in {"revolute", "continuous"}: - return descriptions.JointType.Ry - - if np.allclose(axis_xyz, [0, 0, 1]) and type in {"revolute", "continuous"}: - return descriptions.JointType.Rz - - if np.allclose(axis_xyz, [1, 0, 0]) and type == "prismatic": - return descriptions.JointType.Px - - if np.allclose(axis_xyz, [0, 1, 0]) and type == "prismatic": - return descriptions.JointType.Py - - if np.allclose(axis_xyz, [0, 0, 1]) and type == "prismatic": - return descriptions.JointType.Pz + # Make sure that the axis is a unary vector. + axis_xyz = np.array(axis.xyz.xyz).astype(float) + axis_xyz = axis_xyz / np.linalg.norm(axis_xyz) - if type == "revolute": + if joint_type in {"revolute", "continuous"}: return descriptions.JointGenericAxis( - code=descriptions.JointType.R, axis=np.array(axis_xyz, dtype=float) + joint_type=descriptions.JointType.R, axis=axis_xyz ) - if type == "prismatic": + if joint_type == "prismatic": return descriptions.JointGenericAxis( - code=descriptions.JointType.P, axis=np.array(axis_xyz, dtype=float) + joint_type=descriptions.JointType.P, axis=axis_xyz ) - raise ValueError("Joint not supported", axis_xyz, type) + raise ValueError("Joint not supported", axis_xyz, joint_type) def create_box_collision( diff --git a/src/jaxsim/utils/jaxsim_dataclass.py b/src/jaxsim/utils/jaxsim_dataclass.py index 2c760b0bc..850959840 100644 --- a/src/jaxsim/utils/jaxsim_dataclass.py +++ b/src/jaxsim/utils/jaxsim_dataclass.py @@ -50,7 +50,9 @@ def editable(self: Self, validate: bool = True) -> Iterator[Self]: @contextlib.contextmanager def mutable_context( - self: Self, mutability: Mutability, restore_after_exception: bool = True + self: Self, + mutability: Mutability = Mutability.MUTABLE, + restore_after_exception: bool = True, ) -> Iterator[Self]: """ Context manager to temporarily change the mutability of the object. @@ -86,7 +88,7 @@ def restore_self() -> None: setattr(self, f.name, getattr(self_copy, f.name)) try: - self.set_mutability(mutability) + self.set_mutability(mutability=mutability) yield self if mutability is not Mutability.MUTABLE_NO_VALIDATION: diff --git a/tests/test_api_model.py b/tests/test_api_model.py index 20374a5be..eba2c525f 100644 --- a/tests/test_api_model.py +++ b/tests/test_api_model.py @@ -14,17 +14,17 @@ def test_model_creation_and_reduction( jaxsim_model_ergocub: js.model.JaxSimModel, - jaxsim_model_ergocub_reduced: js.model.JaxSimModel, + prng_key: jax.Array, ): model_full = jaxsim_model_ergocub - model_reduced = jaxsim_model_ergocub_reduced - # Build the data of the full model. - data = js.data.JaxSimModelData.build( + key, subkey = jax.random.split(prng_key, num=2) + data_full = js.data.random_model_data( model=model_full, - base_position=jnp.array([0, 0, 0.8]), + key=subkey, velocity_representation=VelRepr.Inertial, + base_pos_bounds=((0, 0, 0.8), (0, 0, 0.8)), ) # ===== @@ -32,7 +32,7 @@ def test_model_creation_and_reduction( # ===== # Check that the data of the full model is valid. - assert data.valid(model=model_full) + assert data_full.valid(model=model_full) # Build the ROD model from the original description. assert isinstance(model_full.built_from, (str, pathlib.Path)) @@ -47,27 +47,106 @@ def test_model_creation_and_reduction( # Check that all non-fixed joints are in the full model. assert set(joint_names_in_description) == set(model_full.joint_names()) + # ================ + # Reduce the model + # ================ + + # Get the names of the joints to keep in the reduced model. + reduced_joints = tuple( + j + for j in model_full.joint_names() + if "camera" not in j + and "neck" not in j + and "wrist" not in j + and "thumb" not in j + and "index" not in j + and "middle" not in j + and "ring" not in j + and "pinkie" not in j + # + and "elbow" not in j + and "shoulder" not in j + and "torso" not in j + and "r_knee" not in j + ) + + # Reduce the model. + # Note: here we also specify a non-zero position of the removed joints. + # The process should take into account the corresponding joint transforms + # when the link-joint-link chains are lumped together. + model_reduced = js.model.reduce( + model=model_full, + considered_joints=reduced_joints, + locked_joint_positions={ + name: pos + for name, pos in zip( + model_full.joint_names(), + data_full.joint_positions( + model=model_full, joint_names=model_full.joint_names() + ).tolist(), + ) + }, + ) + + # Check DoFs. + assert model_full.dofs() != model_reduced.dofs() + + # Check that all non-fixed joints are in the reduced model. + assert set(reduced_joints) == set(model_reduced.joint_names()) + # Build the data of the reduced model. data_reduced = js.data.JaxSimModelData.build( model=model_reduced, - base_position=jnp.array([0, 0, 0.8]), - velocity_representation=VelRepr.Inertial, + base_position=data_full.base_position(), + base_quaternion=data_full.base_orientation(dcm=False), + joint_positions=data_full.joint_positions( + model=model_full, joint_names=model_reduced.joint_names() + ), + base_linear_velocity=data_full.base_velocity()[0:3], + base_angular_velocity=data_full.base_velocity()[3:6], + joint_velocities=data_full.joint_velocities( + model=model_full, joint_names=model_reduced.joint_names() + ), + velocity_representation=data_full.velocity_representation, ) - # Check that the reduced model data is valid. - assert not data_reduced.valid(model=model_full) - assert data_reduced.valid(model=model_reduced) + # ===================== + # Test against iDynTree + # ===================== - # Check that the total mass is preserved. - assert js.model.total_mass(model=model_full) == pytest.approx( - js.model.total_mass(model=model_reduced) + kin_dyn_full = utils_idyntree.build_kindyncomputations_from_jaxsim_model( + model=model_full, data=data_full ) - # Check that the CoM position is preserved. - assert js.com.com_position(model=model_full, data=data) == pytest.approx( - js.com.com_position(model=model_reduced, data=data_reduced), abs=1e-6 + kin_dyn_reduced = utils_idyntree.build_kindyncomputations_from_jaxsim_model( + model=model_reduced, data=data_reduced ) + # Check that the total mass is preserved. + assert kin_dyn_full.total_mass() == pytest.approx(kin_dyn_reduced.total_mass()) + + # Check that the CoM position match. + assert kin_dyn_full.com_position() == pytest.approx(kin_dyn_reduced.com_position()) + assert kin_dyn_full.com_position() == pytest.approx( + js.com.com_position(model=model_reduced, data=data_reduced) + ) + + # Check that link transforms match. + for link_name, link_idx in zip( + model_reduced.link_names(), + js.link.names_to_idxs( + model=model_reduced, link_names=model_reduced.link_names() + ), + ): + assert kin_dyn_reduced.frame_transform(frame_name=link_name) == pytest.approx( + kin_dyn_full.frame_transform(frame_name=link_name) + ) + assert kin_dyn_reduced.frame_transform(frame_name=link_name) == pytest.approx( + js.link.transform( + model=model_reduced, data=data_reduced, link_index=link_idx + ) + ) + def test_model_properties( jaxsim_models_types: js.model.JaxSimModel, diff --git a/tests/utils_idyntree.py b/tests/utils_idyntree.py index 5344ad15b..7ec6ae1db 100644 --- a/tests/utils_idyntree.py +++ b/tests/utils_idyntree.py @@ -12,7 +12,10 @@ def build_kindyncomputations_from_jaxsim_model( - model: js.model.JaxSimModel, data: js.data.JaxSimModelData + model: js.model.JaxSimModel, + data: js.data.JaxSimModelData, + considered_joints: list[str] | None = None, + removed_joint_positions: dict[str, npt.NDArray | float | int] | None = None, ) -> KinDynComputations: """ Build a `KinDynComputations` from `JaxSimModel` and `JaxSimModelData`. @@ -20,6 +23,10 @@ def build_kindyncomputations_from_jaxsim_model( Args: model: The `JaxSimModel` from which to build the `KinDynComputations`. data: The `JaxSimModelData` from which to build the `KinDynComputations`. + considered_joints: + The list of joint names to consider in the `KinDynComputations`. + removed_joint_positions: + A dictionary defining the positions of the removed joints (default is 0). Returns: The `KinDynComputations` built from the `JaxSimModel` and `JaxSimModelData`. @@ -34,12 +41,41 @@ def build_kindyncomputations_from_jaxsim_model( ) or (isinstance(model.built_from, str) and " "KinDynComputations": + removed_joint_positions: dict[str, npt.NDArray | float | int] | None = None, + ) -> KinDynComputations: # Read the URDF description urdf_string = urdf.read_text() if isinstance(urdf, pathlib.Path) else urdf @@ -99,11 +139,20 @@ def build( # Create the model loader mdl_loader = idt.ModelLoader() + # Handle removed_joint_positions if None + removed_joint_positions = ( + {str(name): float(pos) for name, pos in removed_joint_positions.items()} + if removed_joint_positions is not None + else dict() + ) + # Load the URDF description if not ( mdl_loader.loadModelFromString(urdf_string) if considered_joints is None - else mdl_loader.loadReducedModelFromString(urdf_string, considered_joints) + else mdl_loader.loadReducedModelFromString( + urdf_string, considered_joints, removed_joint_positions + ) ): raise RuntimeError("Failed to load URDF description") @@ -197,6 +246,13 @@ def link_names(self) -> list[str]: self.kin_dyn.getFrameName(i) for i in range(self.kin_dyn.getNrOfLinks()) ] + def frame_names(self) -> list[str]: + + return [ + self.kin_dyn.getFrameName(i) + for i in range(self.kin_dyn.getNrOfLinks(), self.kin_dyn.getNrOfFrames()) + ] + def joint_positions(self) -> npt.NDArray: vector = idt.VectorDynSize() @@ -272,6 +328,26 @@ def frame_transform(self, frame_name: str) -> npt.NDArray: return H + def frame_relative_transform( + self, ref_frame_name: str, frame_name: str + ) -> npt.NDArray: + + if self.kin_dyn.getFrameIndex(ref_frame_name) < 0: + raise ValueError(f"Frame '{ref_frame_name}' does not exist") + + if self.kin_dyn.getFrameIndex(frame_name) < 0: + raise ValueError(f"Frame '{frame_name}' does not exist") + + ref_H_frame: idt.Transform = self.kin_dyn.getRelativeTransform( + ref_frame_name, frame_name + ) + + H = np.eye(4) + H[0:3, 3] = ref_H_frame.getPosition().toNumPy() + H[0:3, 0:3] = ref_H_frame.getRotation().toNumPy() + + return H + def base_velocity(self) -> npt.NDArray: nu = idt.VectorDynSize()