From 97c1c65bdf588214f11aa665e2be44b7495f5a58 Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Fri, 12 Apr 2024 16:26:17 +0200 Subject: [PATCH 01/17] Move computation of KinematicGraph transforms to new class --- src/jaxsim/math/joint_model.py | 13 +- src/jaxsim/parsers/descriptions/model.py | 7 +- src/jaxsim/parsers/kinematic_graph.py | 144 +++++++++++++---------- 3 files changed, 93 insertions(+), 71 deletions(-) diff --git a/src/jaxsim/math/joint_model.py b/src/jaxsim/math/joint_model.py index d5cf436fb..0bd8c956c 100644 --- a/src/jaxsim/math/joint_model.py +++ b/src/jaxsim/math/joint_model.py @@ -15,6 +15,7 @@ JointType, ModelDescription, ) +from jaxsim.parsers.kinematic_graph import KinematicGraphTransforms from .rotation import Rotation @@ -87,21 +88,19 @@ def build(description: ModelDescription) -> JointModel: # w.r.t. the implicit __model__ SDF frame is not the identity). suc_H_i = suc_H_i.at[0].set(ordered_links[0].pose) + # Create the object to compute forward kinematics. + fk = KinematicGraphTransforms(graph=description) + # Compute the parent-to-predecessor and successor-to-child transforms for # each joint belonging to the model. # Note that the joint indices starts from i=1 given our joint model, # therefore the entries at index 0 are not updated. for joint in ordered_joints: λ_H_pre = λ_H_pre.at[joint.index].set( - description.relative_transform( - relative_to=joint.parent.name, - name=joint.name, - ) + fk.relative_transform(relative_to=joint.parent.name, name=joint.name) ) suc_H_i = suc_H_i.at[joint.index].set( - description.relative_transform( - relative_to=joint.name, name=joint.child.name - ) + fk.relative_transform(relative_to=joint.name, name=joint.child.name) ) # Define the DoFs of the base link. diff --git a/src/jaxsim/parsers/descriptions/model.py b/src/jaxsim/parsers/descriptions/model.py index 9dd9e99f1..666571db7 100644 --- a/src/jaxsim/parsers/descriptions/model.py +++ b/src/jaxsim/parsers/descriptions/model.py @@ -4,7 +4,7 @@ from jaxsim import logging -from ..kinematic_graph import KinematicGraph, RootPose +from ..kinematic_graph import KinematicGraph, KinematicGraphTransforms, RootPose from .collision import CollidablePoint, CollisionShape from .joint import JointDescription from .link import LinkDescription @@ -75,6 +75,9 @@ def build_model_from( considered_joints=considered_joints ) + # Create the object to compute forward kinematics. + fk = KinematicGraphTransforms(graph=kinematic_graph) + # Store here the final model collisions final_collisions: List[CollisionShape] = [] @@ -121,7 +124,7 @@ def build_model_from( # relative pose moved_cp = cp.change_link( new_link=real_parent_link_of_shape, - new_H_old=kinematic_graph.relative_transform( + new_H_old=fk.relative_transform( relative_to=real_parent_link_of_shape.name, name=cp.parent_link.name, ), diff --git a/src/jaxsim/parsers/kinematic_graph.py b/src/jaxsim/parsers/kinematic_graph.py index 5841e35d8..7f04b752f 100644 --- a/src/jaxsim/parsers/kinematic_graph.py +++ b/src/jaxsim/parsers/kinematic_graph.py @@ -281,6 +281,9 @@ def reduce(self, considered_joints: List[str]) -> "KinematicGraph": links_dict = copy.deepcopy(full_graph.links_dict) joints_dict = copy.deepcopy(full_graph.joints_dict) + # Create the object to compute forward kinematics. + fk = KinematicGraphTransforms(graph=full_graph) + # The following steps are implemented below in order to create the reduced graph: # # 1. Lump the mass of the removed links into their parent @@ -329,7 +332,7 @@ def reduce(self, considered_joints: List[str]) -> "KinematicGraph": # Lump the link lumped_link = parent_of_link_to_remove.lump_with( link=link_to_remove, - lumped_H_removed=full_graph.relative_transform( + lumped_H_removed=fk.relative_transform( relative_to=parent_of_link_to_remove.name, name=link_to_remove.name ), ) @@ -370,7 +373,7 @@ def reduce(self, considered_joints: List[str]) -> "KinematicGraph": # Update the pose. Note that after the lumping process, the dict entry # links_dict[joint.parent.name] contains the final lumped link with joint.mutable_context(mutability=Mutability.MUTABLE): - joint.pose = full_graph.relative_transform( + joint.pose = fk.relative_transform( relative_to=links_dict[joint.parent.name].name, name=joint.name ) with joint.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION): @@ -424,7 +427,7 @@ def reduce(self, considered_joints: List[str]) -> "KinematicGraph": # Update the connection of the frame frame.parent = new_parent_link - frame.pose = full_graph.relative_transform( + frame.pose = fk.relative_transform( relative_to=new_parent_link.name, name=frame.name ) @@ -462,65 +465,6 @@ def frame_names(self) -> List[str]: """ return list(self.frames_dict.keys()) - def transform(self, name: str) -> npt.NDArray: - """ - Compute the transformation matrix for a given link, joint, or frame. - - Args: - name (str): The name of the link, joint, or frame. - - Returns: - npt.NDArray: The transformation matrix. - """ - if name in self.transform_cache: - return self.transform_cache[name] - - if name in self.joint_names(): - joint = self.joints_dict[name] - - if joint.initial_position != 0.0: - msg = f"Ignoring unsupported initial position of joint '{name}'" - logging.warning(msg=msg) - - transform = self.transform(name=joint.parent.name) @ joint.pose - self.transform_cache[name] = transform - return self.transform_cache[name] - - if name in self.link_names(): - link = self.links_dict[name] - - if link.name == self.root.name: - return link.pose - - parent_joint = self.joints_connection_dict[(link.parent.name, link.name)] - transform = self.transform(name=parent_joint.name) @ link.pose - self.transform_cache[name] = transform - return self.transform_cache[name] - - # It can only be a plain frame - if name not in self.frame_names(): - raise ValueError(name) - - frame = self.frames_dict[name] - transform = self.transform(name=frame.parent.name) @ frame.pose - self.transform_cache[name] = transform - return self.transform_cache[name] - - def relative_transform(self, relative_to: str, name: str) -> npt.NDArray: - """ - Compute the relative transformation matrix between two elements in the kinematic graph. - - Args: - relative_to (str): The name of the reference element. - name (str): The name of the element to compute the relative transformation for. - - Returns: - npt.NDArray: The relative transformation matrix. - """ - return np.linalg.inv(self.transform(name=relative_to)) @ self.transform( - name=name - ) - def print_tree(self) -> None: """ Print the tree structure of the kinematic graph. @@ -606,3 +550,79 @@ def __getitem__(self, key: Union[int, str]) -> descriptions.LinkDescription: return list(iter(self))[key] raise TypeError(type(key).__name__) + + +# ==================== +# Other useful classes +# ==================== + + +@dataclasses.dataclass(frozen=True) +class KinematicGraphTransforms: + + graph: KinematicGraph + + transform_cache: dict[str, npt.NDArray] = dataclasses.field( + default_factory=dict, init=False, repr=False, compare=False + ) + + def transform(self, name: str) -> npt.NDArray: + """ + Compute the transformation matrix for a given link, joint, or frame. + + Args: + name (str): The name of the link, joint, or frame. + + Returns: + npt.NDArray: The transformation matrix. + """ + if name in self.transform_cache: + return self.transform_cache[name] + + if name in self.graph.joint_names(): + joint = self.graph.joints_dict[name] + + if joint.initial_position != 0.0: + msg = f"Ignoring unsupported initial position of joint '{name}'" + logging.warning(msg=msg) + + transform = self.transform(name=joint.parent.name) @ joint.pose + self.transform_cache[name] = transform + return self.transform_cache[name] + + if name in self.graph.link_names(): + link = self.graph.links_dict[name] + + if link.name == self.graph.root.name: + return link.pose + + parent_joint = self.graph.joints_connection_dict[ + (link.parent.name, link.name) + ] + transform = self.transform(name=parent_joint.name) @ link.pose + self.transform_cache[name] = transform + return self.transform_cache[name] + + # It can only be a plain frame + if name not in self.graph.frame_names(): + raise ValueError(name) + + frame = self.graph.frames_dict[name] + transform = self.transform(name=frame.parent.name) @ frame.pose + self.transform_cache[name] = transform + return self.transform_cache[name] + + def relative_transform(self, relative_to: str, name: str) -> npt.NDArray: + """ + Compute the relative transformation matrix between two elements in the kinematic graph. + + Args: + relative_to (str): The name of the reference element. + name (str): The name of the element to compute the relative transformation for. + + Returns: + npt.NDArray: The relative transformation matrix. + """ + return np.linalg.inv(self.transform(name=relative_to)) @ self.transform( + name=name + ) From e9dcb479f8fc58d5d1e215230ccbe47239c84b52 Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Fri, 12 Apr 2024 17:18:54 +0200 Subject: [PATCH 02/17] Simplify handling of joint types --- src/jaxsim/math/joint_model.py | 57 +---------------- src/jaxsim/parsers/descriptions/joint.py | 78 ++++++++++-------------- src/jaxsim/parsers/rod/parser.py | 4 +- src/jaxsim/parsers/rod/utils.py | 49 ++++++--------- 4 files changed, 53 insertions(+), 135 deletions(-) diff --git a/src/jaxsim/math/joint_model.py b/src/jaxsim/math/joint_model.py index 0bd8c956c..dbd684b23 100644 --- a/src/jaxsim/math/joint_model.py +++ b/src/jaxsim/math/joint_model.py @@ -242,16 +242,16 @@ def supported_joint_motion( """ if isinstance(joint_type, JointType): - code = joint_type + type_enum = joint_type elif isinstance(joint_type, JointDescriptor): - code = joint_type.code + type_enum = joint_type.joint_type else: raise ValueError(joint_type) # Prepare the joint position s = jnp.array(joint_position).astype(float) - match code: + match type_enum: case JointType.R: joint_type: JointGenericAxis @@ -277,57 +277,6 @@ def supported_joint_motion( case JointType.F: raise ValueError("Fixed joints shouldn't be here") - case JointType.Rx: - - pre_H_suc = jaxlie.SE3.from_rotation( - rotation=jaxlie.SO3.from_x_radians(theta=s) - ) - - S = jnp.vstack([0, 0, 0, 1.0, 0, 0]) - - case JointType.Ry: - - pre_H_suc = jaxlie.SE3.from_rotation( - rotation=jaxlie.SO3.from_y_radians(theta=s) - ) - - S = jnp.vstack([0, 0, 0, 0, 1.0, 0]) - - case JointType.Rz: - - pre_H_suc = jaxlie.SE3.from_rotation( - rotation=jaxlie.SO3.from_z_radians(theta=s) - ) - - S = jnp.vstack([0, 0, 0, 0, 0, 1.0]) - - case JointType.Px: - - pre_H_suc = jaxlie.SE3.from_rotation_and_translation( - rotation=jaxlie.SO3.identity(), - translation=jnp.array([s, 0.0, 0.0]), - ) - - S = jnp.vstack([1.0, 0, 0, 0, 0, 0]) - - case JointType.Py: - - pre_H_suc = jaxlie.SE3.from_rotation_and_translation( - rotation=jaxlie.SO3.identity(), - translation=jnp.array([0.0, s, 0.0]), - ) - - S = jnp.vstack([0, 1.0, 0, 0, 0, 0]) - - case JointType.Pz: - - pre_H_suc = jaxlie.SE3.from_rotation_and_translation( - rotation=jaxlie.SO3.identity(), - translation=jnp.array([0.0, 0.0, s]), - ) - - S = jnp.vstack([0, 0, 1.0, 0, 0, 0]) - case _: raise ValueError(joint_type) diff --git a/src/jaxsim/parsers/descriptions/joint.py b/src/jaxsim/parsers/descriptions/joint.py index 97db9c86c..d56765826 100644 --- a/src/jaxsim/parsers/descriptions/joint.py +++ b/src/jaxsim/parsers/descriptions/joint.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import dataclasses import enum from typing import Tuple, Union @@ -6,79 +8,61 @@ import numpy as np import numpy.typing as npt +import jaxsim.typing as jtp from jaxsim.utils import JaxsimDataclass, Mutability from .link import LinkDescription +@enum.unique +@enum.verify(enum.CONTINUOUS) class JointType(enum.IntEnum): """ - Enumeration of joint types for robot joints. - - Args: - F: Fixed joint (no movement). - R: Revolute joint (rotation). - P: Prismatic joint (translation). - Rx: Revolute joint with rotation about the X-axis. - Ry: Revolute joint with rotation about the Y-axis. - Rz: Revolute joint with rotation about the Z-axis. - Px: Prismatic joint with translation along the X-axis. - Py: Prismatic joint with translation along the Y-axis. - Pz: Prismatic joint with translation along the Z-axis. + Type of supported joints. """ - F = enum.auto() # Fixed - R = enum.auto() # Revolute - P = enum.auto() # Prismatic + @staticmethod + def _generate_next_value_(name, start, count, last_values): + # Start auto Enum value from 0 instead of 1 + return count + + #: Fixed joint. + F = enum.auto() - # Revolute joints, single axis - Rx = enum.auto() - Ry = enum.auto() - Rz = enum.auto() + #: Revolute joint (1 DoF around axis). + R = enum.auto() - # Prismatic joints, single axis - Px = enum.auto() - Py = enum.auto() - Pz = enum.auto() + #: Prismatic joint (1 DoF along axis). + P = enum.auto() -@dataclasses.dataclass +@jax_dataclasses.pytree_dataclass class JointDescriptor: """ - Description of a joint type with a specific code. - - Args: - code (JointType): The code representing the joint type. - + Base class for joint types requiring to store additional metadata. """ - code: JointType + #: The joint type. + joint_type: JointType - def __hash__(self) -> int: - return hash(self.__repr__()) - -@dataclasses.dataclass +@jax_dataclasses.pytree_dataclass class JointGenericAxis(JointDescriptor): """ - Description of a joint type with a generic axis. - - Attributes: - axis (npt.NDArray): The axis of rotation or translation for the joint. - + A joint requiring the specification of a 3D axis. """ - axis: npt.NDArray + #: The axis of rotation or translation of the joint (must have norm 1). + axis: jtp.Vector - def __post_init__(self): - if np.allclose(self.axis, 0.0): - raise ValueError(self.axis) + def __hash__(self) -> int: + return hash((self.joint_type, tuple(np.array(self.axis).tolist()))) - def __eq__(self, other): - return super().__eq__(other) and np.allclose(self.axis, other.axis) + def __eq__(self, other: JointGenericAxis) -> bool: + if not isinstance(other, JointGenericAxis): + return False - def __hash__(self) -> int: - return hash(self.__repr__()) + return hash(self) == hash(other) @jax_dataclasses.pytree_dataclass diff --git a/src/jaxsim/parsers/rod/parser.py b/src/jaxsim/parsers/rod/parser.py index eef416734..ed49e7702 100644 --- a/src/jaxsim/parsers/rod/parser.py +++ b/src/jaxsim/parsers/rod/parser.py @@ -134,7 +134,7 @@ def extract_model_data( name=j.name, parent=world_link, child=links_dict[j.child], - jtype=utils.axis_to_jtype(axis=j.axis, type=j.type), + jtype=utils.joint_to_joint_type(joint=j), axis=( np.array(j.axis.xyz.xyz) if j.axis is not None @@ -201,7 +201,7 @@ def extract_model_data( name=j.name, parent=links_dict[j.parent], child=links_dict[j.child], - jtype=utils.axis_to_jtype(axis=j.axis, type=j.type), + jtype=utils.joint_to_joint_type(joint=j), axis=( np.array(j.axis.xyz.xyz) if j.axis is not None diff --git a/src/jaxsim/parsers/rod/utils.py b/src/jaxsim/parsers/rod/utils.py index b281abc9f..afdc880d5 100644 --- a/src/jaxsim/parsers/rod/utils.py +++ b/src/jaxsim/parsers/rod/utils.py @@ -1,5 +1,4 @@ import os -from typing import Union import jaxlie import numpy as np @@ -60,57 +59,43 @@ def from_sdf_inertial(inertial: rod.Inertial) -> jtp.Matrix: return M_L.astype(dtype=float) -def axis_to_jtype( - axis: rod.Axis, type: str -) -> Union[descriptions.JointType, descriptions.JointDescriptor]: +def joint_to_joint_type( + joint: rod.Joint, +) -> descriptions.JointType | descriptions.JointDescriptor: """ - Convert an SDF axis to a joint type. + Extract the joint type from an SDF joint. Args: - axis: The SDF axis. - type: The SDF joint type. + joint: The parsed SDF joint. Returns: The corresponding joint type description. """ - if type == "fixed": + axis = joint.axis + joint_type = joint.type + + if joint_type == "fixed": return descriptions.JointType.F if not (axis.xyz is not None and axis.xyz.xyz is not None): raise ValueError("Failed to read axis xyz data") - axis_xyz = np.array(axis.xyz.xyz) - - if np.allclose(axis_xyz, [1, 0, 0]) and type in {"revolute", "continuous"}: - return descriptions.JointType.Rx - - if np.allclose(axis_xyz, [0, 1, 0]) and type in {"revolute", "continuous"}: - return descriptions.JointType.Ry - - if np.allclose(axis_xyz, [0, 0, 1]) and type in {"revolute", "continuous"}: - return descriptions.JointType.Rz - - if np.allclose(axis_xyz, [1, 0, 0]) and type == "prismatic": - return descriptions.JointType.Px - - if np.allclose(axis_xyz, [0, 1, 0]) and type == "prismatic": - return descriptions.JointType.Py - - if np.allclose(axis_xyz, [0, 0, 1]) and type == "prismatic": - return descriptions.JointType.Pz + # Make sure that the axis is a unary vector. + axis_xyz = np.array(axis.xyz.xyz).astype(float) + axis_xyz = axis_xyz / np.linalg.norm(axis_xyz) - if type == "revolute": + if joint_type in {"revolute", "continuous"}: return descriptions.JointGenericAxis( - code=descriptions.JointType.R, axis=np.array(axis_xyz, dtype=float) + joint_type=descriptions.JointType.R, axis=axis_xyz ) - if type == "prismatic": + if joint_type == "prismatic": return descriptions.JointGenericAxis( - code=descriptions.JointType.P, axis=np.array(axis_xyz, dtype=float) + joint_type=descriptions.JointType.P, axis=axis_xyz ) - raise ValueError("Joint not supported", axis_xyz, type) + raise ValueError("Joint not supported", axis_xyz, joint_type) def create_box_collision( From ecb0788516dc6be4a9d9021ea710ad74edf2eabe Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Fri, 12 Apr 2024 11:09:58 +0200 Subject: [PATCH 03/17] Finalize implementation of the Sequence protocol in KinematicGraph --- src/jaxsim/parsers/kinematic_graph.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/src/jaxsim/parsers/kinematic_graph.py b/src/jaxsim/parsers/kinematic_graph.py index 7f04b752f..5d4d2700c 100644 --- a/src/jaxsim/parsers/kinematic_graph.py +++ b/src/jaxsim/parsers/kinematic_graph.py @@ -9,6 +9,7 @@ List, NamedTuple, Optional, + Sequence, Tuple, Union, ) @@ -41,7 +42,7 @@ def __eq__(self, other): @dataclasses.dataclass(frozen=True) -class KinematicGraph: +class KinematicGraph(Sequence[descriptions.LinkDescription]): """ Represents a kinematic graph of links and joints. @@ -518,6 +519,10 @@ def breadth_first_search( yield child + # ================= + # Sequence protocol + # ================= + def __iter__(self) -> Iterable[descriptions.LinkDescription]: yield from KinematicGraph.breadth_first_search(root=self.root) @@ -551,6 +556,14 @@ def __getitem__(self, key: Union[int, str]) -> descriptions.LinkDescription: raise TypeError(type(key).__name__) + def count(self, value: descriptions.LinkDescription) -> int: + return list(iter(self)).count(value) + + def index( + self, value: descriptions.LinkDescription, start: int = 0, stop: int = -1 + ) -> int: + return list(iter(self)).index(value, start, stop) + # ==================== # Other useful classes From 6b346acb583d6500b3a5ad03f4127eb8c8ba08b4 Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Fri, 12 Apr 2024 11:16:57 +0200 Subject: [PATCH 04/17] Refactor KinematicGraphTransforms class --- src/jaxsim/parsers/kinematic_graph.py | 79 +++++++++++++++++++++------ 1 file changed, 61 insertions(+), 18 deletions(-) diff --git a/src/jaxsim/parsers/kinematic_graph.py b/src/jaxsim/parsers/kinematic_graph.py index 5d4d2700c..ba77e160b 100644 --- a/src/jaxsim/parsers/kinematic_graph.py +++ b/src/jaxsim/parsers/kinematic_graph.py @@ -581,61 +581,104 @@ class KinematicGraphTransforms: def transform(self, name: str) -> npt.NDArray: """ - Compute the transformation matrix for a given link, joint, or frame. + Compute the SE(3) transform of elements belonging to the kinematic graph. Args: - name (str): The name of the link, joint, or frame. + name: The name of a link, a joint, or a frame. Returns: - npt.NDArray: The transformation matrix. + The 4x4 transform matrix of the element w.r.t. the model frame. """ + + # If the transform was already computed, return it. if name in self.transform_cache: return self.transform_cache[name] + # If the name is a joint, compute M_H_J transform. if name in self.graph.joint_names(): + + # Get the joint. joint = self.graph.joints_dict[name] if joint.initial_position != 0.0: msg = f"Ignoring unsupported initial position of joint '{name}'" logging.warning(msg=msg) - transform = self.transform(name=joint.parent.name) @ joint.pose - self.transform_cache[name] = transform + # Get the transform of the parent link. + M_H_L = self.transform(name=joint.parent.name) + + # Rename the pose of the predecessor joint frame w.r.t. its parent link. + L_H_pre = joint.pose + + # Compute the joint transform from the predecessor to the successor frame. + # Note: we assume that the joint angle is always 0. + pre_H_J = np.eye(4) + + # Compute the M_H_J transform. + self.transform_cache[name] = M_H_L @ L_H_pre @ pre_H_J return self.transform_cache[name] + # If the name is a link, compute M_H_L transform. if name in self.graph.link_names(): + + # Get the link. link = self.graph.links_dict[name] + # Handle the pose between the __model__ frame and the root link. if link.name == self.graph.root.name: - return link.pose + M_H_B = link.pose + return M_H_B + # Get the joint between the link and its parent. parent_joint = self.graph.joints_connection_dict[ (link.parent.name, link.name) ] - transform = self.transform(name=parent_joint.name) @ link.pose - self.transform_cache[name] = transform + + # Get the transform of the parent joint. + M_H_J = self.transform(name=parent_joint.name) + + # Rename the pose of the link w.r.t. its parent joint. + J_H_L = link.pose + + # Compute the M_H_L transform. + self.transform_cache[name] = M_H_J @ J_H_L return self.transform_cache[name] - # It can only be a plain frame + # It can only be a plain frame. if name not in self.graph.frame_names(): raise ValueError(name) + # Get the frame. frame = self.graph.frames_dict[name] - transform = self.transform(name=frame.parent.name) @ frame.pose - self.transform_cache[name] = transform + + # Get the transform of the parent link. + M_H_L = self.transform(name=frame.parent.name) + + # Rename the pose of the frame w.r.t. its parent link. + L_H_F = frame.pose + + # Compute the M_H_F transform. + self.transform_cache[name] = M_H_L @ L_H_F return self.transform_cache[name] def relative_transform(self, relative_to: str, name: str) -> npt.NDArray: """ - Compute the relative transformation matrix between two elements in the kinematic graph. + Compute the SE(3) relative transform of elements belonging to the kinematic graph. Args: - relative_to (str): The name of the reference element. - name (str): The name of the element to compute the relative transformation for. + relative_to: The name of the reference element. + name: The name of a link, a joint, or a frame. Returns: - npt.NDArray: The relative transformation matrix. + The 4x4 transform matrix of the element w.r.t. the desired frame. """ - return np.linalg.inv(self.transform(name=relative_to)) @ self.transform( - name=name - ) + + import jaxsim.math + + M_H_target = self.transform(name=name) + M_H_R = self.transform(name=relative_to) + + # Compute the relative transform R_H_target, where R is the reference frame, + # and i the frame of the desired link|joint|frame. + return np.array(jaxsim.math.Transform.inverse(M_H_R)) @ M_H_target + From 233855e738d4f9c59dfdd85ff1bdcf597504e07e Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Fri, 12 Apr 2024 16:27:10 +0200 Subject: [PATCH 05/17] Compute KinematicGraph transforms specifying generic joint positions --- src/jaxsim/math/joint_model.py | 3 +- src/jaxsim/parsers/kinematic_graph.py | 88 ++++++++++++++++++++++----- 2 files changed, 75 insertions(+), 16 deletions(-) diff --git a/src/jaxsim/math/joint_model.py b/src/jaxsim/math/joint_model.py index dbd684b23..6e52952a2 100644 --- a/src/jaxsim/math/joint_model.py +++ b/src/jaxsim/math/joint_model.py @@ -275,7 +275,8 @@ def supported_joint_motion( S = jnp.vstack(jnp.hstack([joint_type.axis.squeeze(), jnp.zeros(3)])) case JointType.F: - raise ValueError("Fixed joints shouldn't be here") + pre_H_suc = jaxlie.SE3.identity() + S = jnp.zeros(shape=(6, 1)) case _: raise ValueError(joint_type) diff --git a/src/jaxsim/parsers/kinematic_graph.py b/src/jaxsim/parsers/kinematic_graph.py index ba77e160b..018a83f63 100644 --- a/src/jaxsim/parsers/kinematic_graph.py +++ b/src/jaxsim/parsers/kinematic_graph.py @@ -575,10 +575,57 @@ class KinematicGraphTransforms: graph: KinematicGraph - transform_cache: dict[str, npt.NDArray] = dataclasses.field( + _transform_cache: dict[str, npt.NDArray] = dataclasses.field( default_factory=dict, init=False, repr=False, compare=False ) + _initial_joint_positions: dict[str, float] = dataclasses.field( + init=False, repr=False, compare=False + ) + + def __post_init__(self) -> None: + + super().__setattr__( + "_initial_joint_positions", + {joint.name: joint.initial_position for joint in self.graph.joints}, + ) + + @property + def initial_joint_positions(self) -> npt.NDArray: + + return np.atleast_1d( + np.array(list(self._initial_joint_positions.values())) + ).astype(float) + + @initial_joint_positions.setter + def initial_joint_positions( + self, + positions: npt.NDArray | Sequence, + joint_names: Sequence[str] | None = None, + ) -> None: + + joint_names = ( + joint_names + if joint_names is not None + else list(self._initial_joint_positions.keys()) + ) + + s = np.atleast_1d(np.array(positions).squeeze()) + + if s.size != len(joint_names): + raise ValueError(s.size, len(joint_names)) + + for joint_name in joint_names: + if joint_name not in self._initial_joint_positions: + raise ValueError(joint_name) + + # Clear transform cache. + self._transform_cache.clear() + + # Update initial joint positions. + for joint_name, position in zip(joint_names, s): + self._initial_joint_positions[joint_name] = position + def transform(self, name: str) -> npt.NDArray: """ Compute the SE(3) transform of elements belonging to the kinematic graph. @@ -591,8 +638,8 @@ def transform(self, name: str) -> npt.NDArray: """ # If the transform was already computed, return it. - if name in self.transform_cache: - return self.transform_cache[name] + if name in self._transform_cache: + return self._transform_cache[name] # If the name is a joint, compute M_H_J transform. if name in self.graph.joint_names(): @@ -600,10 +647,6 @@ def transform(self, name: str) -> npt.NDArray: # Get the joint. joint = self.graph.joints_dict[name] - if joint.initial_position != 0.0: - msg = f"Ignoring unsupported initial position of joint '{name}'" - logging.warning(msg=msg) - # Get the transform of the parent link. M_H_L = self.transform(name=joint.parent.name) @@ -611,12 +654,14 @@ def transform(self, name: str) -> npt.NDArray: L_H_pre = joint.pose # Compute the joint transform from the predecessor to the successor frame. - # Note: we assume that the joint angle is always 0. - pre_H_J = np.eye(4) + pre_H_J = self.pre_H_suc( + joint_type=joint.jtype, + joint_position=self._initial_joint_positions[joint.name], + ) # Compute the M_H_J transform. - self.transform_cache[name] = M_H_L @ L_H_pre @ pre_H_J - return self.transform_cache[name] + self._transform_cache[name] = M_H_L @ L_H_pre @ pre_H_J + return self._transform_cache[name] # If the name is a link, compute M_H_L transform. if name in self.graph.link_names(): @@ -641,8 +686,8 @@ def transform(self, name: str) -> npt.NDArray: J_H_L = link.pose # Compute the M_H_L transform. - self.transform_cache[name] = M_H_J @ J_H_L - return self.transform_cache[name] + self._transform_cache[name] = M_H_J @ J_H_L + return self._transform_cache[name] # It can only be a plain frame. if name not in self.graph.frame_names(): @@ -658,8 +703,8 @@ def transform(self, name: str) -> npt.NDArray: L_H_F = frame.pose # Compute the M_H_F transform. - self.transform_cache[name] = M_H_L @ L_H_F - return self.transform_cache[name] + self._transform_cache[name] = M_H_L @ L_H_F + return self._transform_cache[name] def relative_transform(self, relative_to: str, name: str) -> npt.NDArray: """ @@ -682,3 +727,16 @@ def relative_transform(self, relative_to: str, name: str) -> npt.NDArray: # and i the frame of the desired link|joint|frame. return np.array(jaxsim.math.Transform.inverse(M_H_R)) @ M_H_target + @staticmethod + def pre_H_suc( + joint_type: descriptions.JointType | descriptions.JointDescriptor, + joint_position: float | None = None, + ) -> npt.NDArray: + + import jaxsim.math + + return np.array( + jaxsim.math.supported_joint_motion( + joint_type=joint_type, joint_position=joint_position + )[0] + ) From 2b05684bc05bcb2bd1d6896ed49c44226b2c76d4 Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Mon, 15 Apr 2024 13:00:35 +0200 Subject: [PATCH 06/17] Keep track of removed joints due to lumping --- src/jaxsim/parsers/descriptions/model.py | 13 ++++--- src/jaxsim/parsers/kinematic_graph.py | 45 ++++++++++++++++++------ 2 files changed, 44 insertions(+), 14 deletions(-) diff --git a/src/jaxsim/parsers/descriptions/model.py b/src/jaxsim/parsers/descriptions/model.py index 666571db7..51be58c3d 100644 --- a/src/jaxsim/parsers/descriptions/model.py +++ b/src/jaxsim/parsers/descriptions/model.py @@ -142,7 +142,9 @@ def build_model_from( root=kinematic_graph.root, joints=kinematic_graph.joints, frames=kinematic_graph.frames, + _joints_removed=kinematic_graph._joints_removed, ) + assert kinematic_graph.root.name == base_link_name, kinematic_graph.root.name return model @@ -161,15 +163,12 @@ def reduce(self, considered_joints: List[str]) -> "ModelDescription": ValueError: If the specified joints are not part of the model. """ - msg = "The model reduction logic assumes that removed joints have zero angles" - logging.info(msg=msg) - if len(set(considered_joints) - set(self.joint_names())) != 0: extra_joints = set(considered_joints) - set(self.joint_names()) msg = f"Found joints not part of the model: {extra_joints}" raise ValueError(msg) - return ModelDescription.build_model_from( + reduced_model_description = ModelDescription.build_model_from( name=self.name, links=list(self.links_dict.values()), joints=self.joints, @@ -180,6 +179,12 @@ def reduce(self, considered_joints: List[str]) -> "ModelDescription": considered_joints=considered_joints, ) + # Include the unconnected/removed joints from the original model. + for joint in self._joints_removed: + reduced_model_description._joints_removed.append(joint) + + return reduced_model_description + def update_collision_shape_of_link(self, link_name: str, enabled: bool) -> None: """ Enable or disable collision shapes associated with a link. diff --git a/src/jaxsim/parsers/kinematic_graph.py b/src/jaxsim/parsers/kinematic_graph.py index 018a83f63..9c3d1bb8b 100644 --- a/src/jaxsim/parsers/kinematic_graph.py +++ b/src/jaxsim/parsers/kinematic_graph.py @@ -77,6 +77,12 @@ class KinematicGraph(Sequence[descriptions.LinkDescription]): repr=False, compare=False, default_factory=dict ) + # Private attribute storing the unconnected joints from the parsed model and + # the joints removed after model reduction. + _joints_removed: list[descriptions.JointDescription] = dataclasses.field( + default_factory=list, repr=False, compare=False + ) + @functools.cached_property def links_dict(self) -> Dict[str, descriptions.LinkDescription]: return {l.name: l for l in iter(self)} @@ -154,15 +160,24 @@ def build_from( # Couple links and joints and create the graph of links. # Note that the pose of the frames is not updated; it's the caller's # responsibility to update their pose if they want to use them. - graph_root_node, graph_joints, graph_frames = KinematicGraph.create_graph( - links=links, joints=joints, root_link_name=root_link_name + graph_root_node, graph_joints, graph_frames, unconnected_joints = ( + KinematicGraph.create_graph( + links=links, joints=joints, root_link_name=root_link_name + ) ) for frame in graph_frames: logging.warning(msg=f"Ignoring unconnected link / frame: '{frame.name}'") + for joint in unconnected_joints: + logging.warning(msg=f"Ignoring unconnected joint: '{joint.name}'") + return KinematicGraph( - root=graph_root_node, joints=graph_joints, frames=[], root_pose=root_pose + root=graph_root_node, + joints=graph_joints, + frames=[], + root_pose=root_pose, + _joints_removed=unconnected_joints, ) @staticmethod @@ -174,9 +189,10 @@ def create_graph( descriptions.LinkDescription, List[descriptions.JointDescription], List[descriptions.LinkDescription], + list[descriptions.JointDescription], ]: """ - Create a kinematic graph from lists of links and joints. + Create a kinematic graph from the lists of parsed links and joints. Args: links (List[descriptions.LinkDescription]): A list of link descriptions. @@ -184,8 +200,9 @@ def create_graph( root_link_name (str): The name of the root link. Returns: - Tuple[descriptions.LinkDescription, List[descriptions.JointDescription], List[descriptions.LinkDescription]]: - A tuple containing the root link, list of joints, and list of frames in the graph. + A tuple containing the root node with the full kinematic graph as child nodes, + the list of joints associated to graph nodes, the list of frames rigidly + attached to graph nodes, and the list of joints not part of the graph. """ # Create a dict that maps link name to the link, for easy retrieval @@ -247,6 +264,7 @@ def create_graph( links_dict[root_link_name].mutable(mutable=False), list(set(joints) - set(removed_joints)), frames, + list(set(removed_joints)), ) def reduce(self, considered_joints: List[str]) -> "KinematicGraph": @@ -400,10 +418,12 @@ def reduce(self, considered_joints: List[str]) -> "KinematicGraph": # Create the reduced graph data. We pass the full list of links so that those # that are not part of the graph will be returned as frames. - reduced_root_node, reduced_joints, reduced_frames = KinematicGraph.create_graph( - links=list(full_graph_links_dict.values()), - joints=[joints_dict[joint_name] for joint_name in considered_joints], - root_link_name=full_graph.root.name, + reduced_root_node, reduced_joints, reduced_frames, unconnected_joints = ( + KinematicGraph.create_graph( + links=list(full_graph_links_dict.values()), + joints=[joints_dict[joint_name] for joint_name in considered_joints], + root_link_name=full_graph.root.name, + ) ) # Create the reduced graph @@ -412,6 +432,11 @@ def reduce(self, considered_joints: List[str]) -> "KinematicGraph": joints=reduced_joints, frames=reduced_frames, root_pose=full_graph.root_pose, + _joints_removed=( + self._joints_removed + + unconnected_joints + + [joints_dict[name] for name in joint_names_to_remove] + ), ) # ================================================================ From 50edb640c5b73e7730881f5babcfb0add8e06b83 Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Mon, 15 Apr 2024 12:56:48 +0200 Subject: [PATCH 07/17] When reducing a model, allow passing the positions of removed joints --- src/jaxsim/api/model.py | 45 ++++++++++++++++++++++++++++++++--------- 1 file changed, 36 insertions(+), 9 deletions(-) diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index 987a2e9f9..e74245139 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -1,5 +1,6 @@ from __future__ import annotations +import copy import dataclasses import functools import pathlib @@ -257,24 +258,50 @@ def link_names(self) -> tuple[str, ...]: # ===================== -def reduce(model: JaxSimModel, considered_joints: tuple[str, ...]) -> JaxSimModel: +def reduce( + model: JaxSimModel, + considered_joints: tuple[str, ...], + joint_positions_locked: dict[str, jtp.Float] | None = None, +) -> JaxSimModel: """ Reduce the model by lumping together the links connected by removed joints. Args: model: The model to reduce. considered_joints: The sequence of joints to consider. - - Note: - If considered_joints contains joints not existing in the model, the method - will raise an exception. If considered_joints is empty, the method will - return a copy of the input model. + joint_positions_locked: + A dictionary containing the positions of the joints to be considered + in the reduction process. The removed joints in the reduced model + will have their position locked to the value in this dictionary. + If a joint is not present in the dictionary, its position is set to zero. """ + joint_positions_locked = ( + joint_positions_locked if joint_positions_locked is not None else {} + ) + + # If locked joints are passed, make sure that they are valid. + if not set(joint_positions_locked).issubset(model.joint_names()): + new_joints = set(model.joint_names()) - set(joint_positions_locked) + raise ValueError(f"Passed joints not existing in the model: {new_joints}") + + # Copy the model description with a deep copy of the joints. + intermediate_description = dataclasses.replace( + model.description.get(), joints=copy.deepcopy(model.description.get().joints) + ) + + # Update the initial position of the joints. + # This is necessary to compute the correct pose of the link pairs connected + # to removed joints. + for joint_name in model.joint_names(): + j = intermediate_description.joints_dict[joint_name] + with j.mutable_context(): + j.initial_position = float(joint_positions_locked.get(joint_name, 0.0)) + # Reduce the model description. - # If considered_joints contains joints not existing in the model, the method - # will raise an exception. - reduced_intermediate_description = model.description.obj.reduce( + # If `considered_joints` contains joints not existing in the model, + # the method will raise an exception. + reduced_intermediate_description = intermediate_description.reduce( considered_joints=list(considered_joints) ) From 762d56cbf21b571f139273b74f6d4fad9651f96e Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Wed, 17 Apr 2024 17:58:36 +0200 Subject: [PATCH 08/17] Add default mutability in JaxsimDataclass.mutable_context --- src/jaxsim/utils/jaxsim_dataclass.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/jaxsim/utils/jaxsim_dataclass.py b/src/jaxsim/utils/jaxsim_dataclass.py index 2c760b0bc..850959840 100644 --- a/src/jaxsim/utils/jaxsim_dataclass.py +++ b/src/jaxsim/utils/jaxsim_dataclass.py @@ -50,7 +50,9 @@ def editable(self: Self, validate: bool = True) -> Iterator[Self]: @contextlib.contextmanager def mutable_context( - self: Self, mutability: Mutability, restore_after_exception: bool = True + self: Self, + mutability: Mutability = Mutability.MUTABLE, + restore_after_exception: bool = True, ) -> Iterator[Self]: """ Context manager to temporarily change the mutability of the object. @@ -86,7 +88,7 @@ def restore_self() -> None: setattr(self, f.name, getattr(self_copy, f.name)) try: - self.set_mutability(mutability) + self.set_mutability(mutability=mutability) yield self if mutability is not Mutability.MUTABLE_NO_VALIDATION: From a4214372c1ebf27c52c2fcb37fbb2e78c3c38b47 Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Wed, 17 Apr 2024 17:59:42 +0200 Subject: [PATCH 09/17] Extend model reduction test --- tests/test_api_model.py | 97 ++++++++++++++++++++++++++++++++++++++--- 1 file changed, 90 insertions(+), 7 deletions(-) diff --git a/tests/test_api_model.py b/tests/test_api_model.py index 20374a5be..4697973be 100644 --- a/tests/test_api_model.py +++ b/tests/test_api_model.py @@ -14,17 +14,17 @@ def test_model_creation_and_reduction( jaxsim_model_ergocub: js.model.JaxSimModel, - jaxsim_model_ergocub_reduced: js.model.JaxSimModel, + prng_key: jax.Array, ): model_full = jaxsim_model_ergocub - model_reduced = jaxsim_model_ergocub_reduced - # Build the data of the full model. - data = js.data.JaxSimModelData.build( + key, subkey = jax.random.split(prng_key, num=2) + data = js.data.random_model_data( model=model_full, - base_position=jnp.array([0, 0, 0.8]), + key=subkey, velocity_representation=VelRepr.Inertial, + base_pos_bounds=((0, 0, 0.8), (0, 0, 0.8)), ) # ===== @@ -47,11 +47,65 @@ def test_model_creation_and_reduction( # Check that all non-fixed joints are in the full model. assert set(joint_names_in_description) == set(model_full.joint_names()) + # ================ + # Reduce the model + # ================ + + # Get the names of the joints to keep in the reduced model. + reduced_joints = tuple( + j + for j in model_full.joint_names() + if "camera" not in j + and "neck" not in j + and "wrist" not in j + and "thumb" not in j + and "index" not in j + and "middle" not in j + and "ring" not in j + and "pinkie" not in j + ) + + # Removed joints. + removed_joints = set(model_full.joint_names()) - set(reduced_joints) + + # Reduce the model. + # Note: here we also specify a non-zero position of the removed joints. + # The process should take into account the corresponding joint transforms + # when the link-joint-link chains are lumped together. + model_reduced = js.model.reduce( + model=model_full, + considered_joints=reduced_joints, + joint_positions_locked={ + name: pos + for name, pos in zip( + removed_joints, + data.joint_positions( + model=model_full, joint_names=tuple(removed_joints) + ).tolist(), + ) + }, + ) + + # Check DoFs. + assert model_full.dofs() != model_reduced.dofs() + + # Check that all non-fixed joints are in the reduced model. + assert set(reduced_joints) == set(model_reduced.joint_names()) + # Build the data of the reduced model. data_reduced = js.data.JaxSimModelData.build( model=model_reduced, - base_position=jnp.array([0, 0, 0.8]), - velocity_representation=VelRepr.Inertial, + base_position=data.base_position(), + base_quaternion=data.base_orientation(dcm=False), + joint_positions=data.joint_positions( + model=model_full, joint_names=model_reduced.joint_names() + ), + base_linear_velocity=data.base_velocity()[0:3], + base_angular_velocity=data.base_velocity()[3:6], + joint_velocities=data.joint_velocities( + model=model_full, joint_names=model_reduced.joint_names() + ), + velocity_representation=data.velocity_representation, ) # Check that the reduced model data is valid. @@ -68,6 +122,35 @@ def test_model_creation_and_reduction( js.com.com_position(model=model_reduced, data=data_reduced), abs=1e-6 ) + # Check that joint serialization works. + assert data.joint_positions( + model=model_full, joint_names=model_reduced.joint_names() + ) == pytest.approx(data_reduced.joint_positions()) + assert data.joint_velocities( + model=model_full, joint_names=model_reduced.joint_names() + ) == pytest.approx(data_reduced.joint_velocities()) + + # Check that link transforms are preserved. + for link_name in model_reduced.link_names(): + W_H_L_full = js.link.transform( + model=model_full, + data=data, + link_index=js.link.name_to_idx(model=model_full, link_name=link_name), + ) + W_H_L_reduced = js.link.transform( + model=model_reduced, + data=data_reduced, + link_index=js.link.name_to_idx(model=model_reduced, link_name=link_name), + ) + assert W_H_L_full == pytest.approx(W_H_L_reduced) + + # Check that collidable point positions are preserved. + assert js.contact.collidable_point_positions( + model=model_full, data=data + ) == pytest.approx( + js.contact.collidable_point_positions(model=model_reduced, data=data_reduced) + ) + def test_model_properties( jaxsim_models_types: js.model.JaxSimModel, From b968546010be7ca0008708c472244cd16e443e61 Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Mon, 22 Apr 2024 16:30:20 +0200 Subject: [PATCH 10/17] Fix model reduction logic We only need to set the positions of the joints to remove. Passing the angles of the kept joints produces a wrong model with mismatching kinematics and collidable points. --- src/jaxsim/api/model.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index e74245139..f4e2fe54a 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -4,7 +4,7 @@ import dataclasses import functools import pathlib -from typing import Any +from typing import Any, Sequence import jax import jax.numpy as jnp @@ -56,7 +56,7 @@ def build_from_model_description( *, terrain: jaxsim.terrain.Terrain | None = None, is_urdf: bool | None = None, - considered_joints: list[str] | None = None, + considered_joints: Sequence[str] | None = None, ) -> JaxSimModel: """ Build a Model object from a model description. @@ -293,7 +293,7 @@ def reduce( # Update the initial position of the joints. # This is necessary to compute the correct pose of the link pairs connected # to removed joints. - for joint_name in model.joint_names(): + for joint_name in set(model.joint_names()) - set(considered_joints): j = intermediate_description.joints_dict[joint_name] with j.mutable_context(): j.initial_position = float(joint_positions_locked.get(joint_name, 0.0)) From 403e69c6d6c4ae083cb5843f136f11a05c9c80db Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Tue, 30 Apr 2024 10:07:17 +0200 Subject: [PATCH 11/17] Preserve frames of the original graph during the reduction process --- src/jaxsim/parsers/kinematic_graph.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/src/jaxsim/parsers/kinematic_graph.py b/src/jaxsim/parsers/kinematic_graph.py index 9c3d1bb8b..040c6d332 100644 --- a/src/jaxsim/parsers/kinematic_graph.py +++ b/src/jaxsim/parsers/kinematic_graph.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import copy import dataclasses import functools @@ -267,16 +269,22 @@ def create_graph( list(set(removed_joints)), ) - def reduce(self, considered_joints: List[str]) -> "KinematicGraph": + def reduce(self, considered_joints: List[str]) -> KinematicGraph: """ - Reduce the kinematic graph by removing specified joints and lumping the mass and inertia of removed links into their parent links. + Reduce the kinematic graph by removing unspecified joints. + + When a joint is removed, the mass and inertia of its child link are lumped + with those of its parent link, obtaining a new link that combines the two. + The description of the removed joint specifies the default angle (usually 0) + that is considered when the joint is removed. Args: - considered_joints (List[str]): A list of joint names to consider. + considered_joints: A list of joint names to consider. Returns: - KinematicGraph: The reduced kinematic graph. + The reduced kinematic graph. """ + # The current object represents the complete kinematic graph full_graph = self @@ -430,7 +438,7 @@ def reduce(self, considered_joints: List[str]) -> "KinematicGraph": reduced_graph = KinematicGraph( root=reduced_root_node, joints=reduced_joints, - frames=reduced_frames, + frames=self.frames + reduced_frames, root_pose=full_graph.root_pose, _joints_removed=( self._joints_removed From 62dfa694e6ad122ddbb0caff7d6e86a0726657c9 Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Tue, 30 Apr 2024 10:12:02 +0200 Subject: [PATCH 12/17] Update model reduction test --- tests/test_api_model.py | 96 ++++++++++++++++++++--------------------- 1 file changed, 46 insertions(+), 50 deletions(-) diff --git a/tests/test_api_model.py b/tests/test_api_model.py index 4697973be..c6e229d68 100644 --- a/tests/test_api_model.py +++ b/tests/test_api_model.py @@ -20,7 +20,7 @@ def test_model_creation_and_reduction( model_full = jaxsim_model_ergocub key, subkey = jax.random.split(prng_key, num=2) - data = js.data.random_model_data( + data_full = js.data.random_model_data( model=model_full, key=subkey, velocity_representation=VelRepr.Inertial, @@ -32,7 +32,7 @@ def test_model_creation_and_reduction( # ===== # Check that the data of the full model is valid. - assert data.valid(model=model_full) + assert data_full.valid(model=model_full) # Build the ROD model from the original description. assert isinstance(model_full.built_from, (str, pathlib.Path)) @@ -63,11 +63,13 @@ def test_model_creation_and_reduction( and "middle" not in j and "ring" not in j and "pinkie" not in j + # + and "elbow" not in j + and "shoulder" not in j + and "torso" not in j + and "r_knee" not in j ) - # Removed joints. - removed_joints = set(model_full.joint_names()) - set(reduced_joints) - # Reduce the model. # Note: here we also specify a non-zero position of the removed joints. # The process should take into account the corresponding joint transforms @@ -78,9 +80,9 @@ def test_model_creation_and_reduction( joint_positions_locked={ name: pos for name, pos in zip( - removed_joints, - data.joint_positions( - model=model_full, joint_names=tuple(removed_joints) + model_full.joint_names(), + data_full.joint_positions( + model=model_full, joint_names=model_full.joint_names() ).tolist(), ) }, @@ -95,62 +97,56 @@ def test_model_creation_and_reduction( # Build the data of the reduced model. data_reduced = js.data.JaxSimModelData.build( model=model_reduced, - base_position=data.base_position(), - base_quaternion=data.base_orientation(dcm=False), - joint_positions=data.joint_positions( + base_position=data_full.base_position(), + base_quaternion=data_full.base_orientation(dcm=False), + joint_positions=data_full.joint_positions( model=model_full, joint_names=model_reduced.joint_names() ), - base_linear_velocity=data.base_velocity()[0:3], - base_angular_velocity=data.base_velocity()[3:6], - joint_velocities=data.joint_velocities( + base_linear_velocity=data_full.base_velocity()[0:3], + base_angular_velocity=data_full.base_velocity()[3:6], + joint_velocities=data_full.joint_velocities( model=model_full, joint_names=model_reduced.joint_names() ), - velocity_representation=data.velocity_representation, + velocity_representation=data_full.velocity_representation, ) - # Check that the reduced model data is valid. - assert not data_reduced.valid(model=model_full) - assert data_reduced.valid(model=model_reduced) + # ===================== + # Test against iDynTree + # ===================== - # Check that the total mass is preserved. - assert js.model.total_mass(model=model_full) == pytest.approx( - js.model.total_mass(model=model_reduced) + kin_dyn_full = utils_idyntree.build_kindyncomputations_from_jaxsim_model( + model=model_full, data=data_full ) - # Check that the CoM position is preserved. - assert js.com.com_position(model=model_full, data=data) == pytest.approx( - js.com.com_position(model=model_reduced, data=data_reduced), abs=1e-6 + kin_dyn_reduced = utils_idyntree.build_kindyncomputations_from_jaxsim_model( + model=model_reduced, data=data_reduced ) - # Check that joint serialization works. - assert data.joint_positions( - model=model_full, joint_names=model_reduced.joint_names() - ) == pytest.approx(data_reduced.joint_positions()) - assert data.joint_velocities( - model=model_full, joint_names=model_reduced.joint_names() - ) == pytest.approx(data_reduced.joint_velocities()) - - # Check that link transforms are preserved. - for link_name in model_reduced.link_names(): - W_H_L_full = js.link.transform( - model=model_full, - data=data, - link_index=js.link.name_to_idx(model=model_full, link_name=link_name), - ) - W_H_L_reduced = js.link.transform( - model=model_reduced, - data=data_reduced, - link_index=js.link.name_to_idx(model=model_reduced, link_name=link_name), - ) - assert W_H_L_full == pytest.approx(W_H_L_reduced) + # Check that the total mass is preserved. + assert kin_dyn_full.total_mass() == pytest.approx(kin_dyn_reduced.total_mass()) - # Check that collidable point positions are preserved. - assert js.contact.collidable_point_positions( - model=model_full, data=data - ) == pytest.approx( - js.contact.collidable_point_positions(model=model_reduced, data=data_reduced) + # Check that the CoM position match. + assert kin_dyn_full.com_position() == pytest.approx(kin_dyn_reduced.com_position()) + assert kin_dyn_full.com_position() == pytest.approx( + js.com.com_position(model=model_reduced, data=data_reduced) ) + # Check that link transforms match. + for link_name, link_idx in zip( + model_reduced.link_names(), + js.link.names_to_idxs( + model=model_reduced, link_names=model_reduced.link_names() + ), + ): + assert kin_dyn_reduced.frame_transform(frame_name=link_name) == pytest.approx( + kin_dyn_full.frame_transform(frame_name=link_name) + ) + assert kin_dyn_reduced.frame_transform(frame_name=link_name) == pytest.approx( + js.link.transform( + model=model_reduced, data=data_reduced, link_index=link_idx + ) + ) + def test_model_properties( jaxsim_models_types: js.model.JaxSimModel, From 65f44c4f8d461f0a3416b64872c2a1299f8fec9f Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Thu, 2 May 2024 08:59:26 +0200 Subject: [PATCH 13/17] Update iDynTree required version --- environment.yml | 2 +- setup.cfg | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/environment.yml b/environment.yml index 9b24a505b..cc6a4023b 100644 --- a/environment.yml +++ b/environment.yml @@ -22,7 +22,7 @@ dependencies: - isort - pre-commit # [testing] - - idyntree + - idyntree >= 12.2.1 - pytest - pytest-icdiff - robot_descriptions diff --git a/setup.cfg b/setup.cfg index 33e19bd86..887110b8a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -73,7 +73,7 @@ style = isort pre-commit testing = - idyntree + idyntree >= 12.2.1 pytest >=6.0 pytest-icdiff robot-descriptions From e6faf10a3ef86b17da965b1382f5e7f6cf23fa97 Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Thu, 2 May 2024 09:02:14 +0200 Subject: [PATCH 14/17] Restore support to Python 3.10 --- src/jaxsim/parsers/descriptions/joint.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/jaxsim/parsers/descriptions/joint.py b/src/jaxsim/parsers/descriptions/joint.py index d56765826..26af59000 100644 --- a/src/jaxsim/parsers/descriptions/joint.py +++ b/src/jaxsim/parsers/descriptions/joint.py @@ -15,7 +15,6 @@ @enum.unique -@enum.verify(enum.CONTINUOUS) class JointType(enum.IntEnum): """ Type of supported joints. From cf184f27b0570a1183cbe4a5125bddc507612a38 Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Thu, 2 May 2024 09:24:59 +0200 Subject: [PATCH 15/17] Update iDynTree wrapper to support reducing models with non-zero angles --- tests/utils_idyntree.py | 57 ++++++++++++++++++++++++++++++++++++++--- 1 file changed, 53 insertions(+), 4 deletions(-) diff --git a/tests/utils_idyntree.py b/tests/utils_idyntree.py index 5344ad15b..7e8048841 100644 --- a/tests/utils_idyntree.py +++ b/tests/utils_idyntree.py @@ -12,7 +12,10 @@ def build_kindyncomputations_from_jaxsim_model( - model: js.model.JaxSimModel, data: js.data.JaxSimModelData + model: js.model.JaxSimModel, + data: js.data.JaxSimModelData, + considered_joints: list[str] | None = None, + removed_joint_positions: dict[str, npt.NDArray | float | int] | None = None, ) -> KinDynComputations: """ Build a `KinDynComputations` from `JaxSimModel` and `JaxSimModelData`. @@ -20,6 +23,10 @@ def build_kindyncomputations_from_jaxsim_model( Args: model: The `JaxSimModel` from which to build the `KinDynComputations`. data: The `JaxSimModelData` from which to build the `KinDynComputations`. + considered_joints: + The list of joint names to consider in the `KinDynComputations`. + removed_joint_positions: + A dictionary defining the positions of the removed joints (default is 0). Returns: The `KinDynComputations` built from the `JaxSimModel` and `JaxSimModelData`. @@ -34,12 +41,41 @@ def build_kindyncomputations_from_jaxsim_model( ) or (isinstance(model.built_from, str) and " "KinDynComputations": + removed_joint_positions: dict[str, npt.NDArray | float | int] | None = None, + ) -> KinDynComputations: # Read the URDF description urdf_string = urdf.read_text() if isinstance(urdf, pathlib.Path) else urdf @@ -99,11 +139,20 @@ def build( # Create the model loader mdl_loader = idt.ModelLoader() + # Handle removed_joint_positions if None + removed_joint_positions = ( + {str(name): float(pos) for name, pos in removed_joint_positions.items()} + if removed_joint_positions is not None + else dict() + ) + # Load the URDF description if not ( mdl_loader.loadModelFromString(urdf_string) if considered_joints is None - else mdl_loader.loadReducedModelFromString(urdf_string, considered_joints) + else mdl_loader.loadReducedModelFromString( + urdf_string, considered_joints, removed_joint_positions + ) ): raise RuntimeError("Failed to load URDF description") From e83e8e0ed741acf77c5a7fc8c1a834fbbf5f0a86 Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Thu, 2 May 2024 09:25:30 +0200 Subject: [PATCH 16/17] Update iDynTree wrapper to provide frame quantities --- tests/utils_idyntree.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/tests/utils_idyntree.py b/tests/utils_idyntree.py index 7e8048841..7ec6ae1db 100644 --- a/tests/utils_idyntree.py +++ b/tests/utils_idyntree.py @@ -246,6 +246,13 @@ def link_names(self) -> list[str]: self.kin_dyn.getFrameName(i) for i in range(self.kin_dyn.getNrOfLinks()) ] + def frame_names(self) -> list[str]: + + return [ + self.kin_dyn.getFrameName(i) + for i in range(self.kin_dyn.getNrOfLinks(), self.kin_dyn.getNrOfFrames()) + ] + def joint_positions(self) -> npt.NDArray: vector = idt.VectorDynSize() @@ -321,6 +328,26 @@ def frame_transform(self, frame_name: str) -> npt.NDArray: return H + def frame_relative_transform( + self, ref_frame_name: str, frame_name: str + ) -> npt.NDArray: + + if self.kin_dyn.getFrameIndex(ref_frame_name) < 0: + raise ValueError(f"Frame '{ref_frame_name}' does not exist") + + if self.kin_dyn.getFrameIndex(frame_name) < 0: + raise ValueError(f"Frame '{frame_name}' does not exist") + + ref_H_frame: idt.Transform = self.kin_dyn.getRelativeTransform( + ref_frame_name, frame_name + ) + + H = np.eye(4) + H[0:3, 3] = ref_H_frame.getPosition().toNumPy() + H[0:3, 0:3] = ref_H_frame.getRotation().toNumPy() + + return H + def base_velocity(self) -> npt.NDArray: nu = idt.VectorDynSize() From 82a283bc3703d7ce27760d246b6664500c97ef81 Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Thu, 2 May 2024 10:05:46 +0200 Subject: [PATCH 17/17] Rename argument specifying locked joint positions --- src/jaxsim/api/model.py | 18 +++++++++--------- tests/test_api_model.py | 2 +- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index f4e2fe54a..9d14a8fef 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -261,7 +261,7 @@ def link_names(self) -> tuple[str, ...]: def reduce( model: JaxSimModel, considered_joints: tuple[str, ...], - joint_positions_locked: dict[str, jtp.Float] | None = None, + locked_joint_positions: dict[str, jtp.Float] | None = None, ) -> JaxSimModel: """ Reduce the model by lumping together the links connected by removed joints. @@ -269,20 +269,20 @@ def reduce( Args: model: The model to reduce. considered_joints: The sequence of joints to consider. - joint_positions_locked: + locked_joint_positions: A dictionary containing the positions of the joints to be considered in the reduction process. The removed joints in the reduced model - will have their position locked to the value in this dictionary. - If a joint is not present in the dictionary, its position is set to zero. + will have their position locked to their value in this dictionary. + If a joint is not part of the dictionary, its position is set to zero. """ - joint_positions_locked = ( - joint_positions_locked if joint_positions_locked is not None else {} + locked_joint_positions = ( + locked_joint_positions if locked_joint_positions is not None else {} ) # If locked joints are passed, make sure that they are valid. - if not set(joint_positions_locked).issubset(model.joint_names()): - new_joints = set(model.joint_names()) - set(joint_positions_locked) + if not set(locked_joint_positions).issubset(model.joint_names()): + new_joints = set(model.joint_names()) - set(locked_joint_positions) raise ValueError(f"Passed joints not existing in the model: {new_joints}") # Copy the model description with a deep copy of the joints. @@ -296,7 +296,7 @@ def reduce( for joint_name in set(model.joint_names()) - set(considered_joints): j = intermediate_description.joints_dict[joint_name] with j.mutable_context(): - j.initial_position = float(joint_positions_locked.get(joint_name, 0.0)) + j.initial_position = float(locked_joint_positions.get(joint_name, 0.0)) # Reduce the model description. # If `considered_joints` contains joints not existing in the model, diff --git a/tests/test_api_model.py b/tests/test_api_model.py index c6e229d68..eba2c525f 100644 --- a/tests/test_api_model.py +++ b/tests/test_api_model.py @@ -77,7 +77,7 @@ def test_model_creation_and_reduction( model_reduced = js.model.reduce( model=model_full, considered_joints=reduced_joints, - joint_positions_locked={ + locked_joint_positions={ name: pos for name, pos in zip( model_full.joint_names(),