Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix processing of parsed frames #158

Merged
merged 6 commits into from
May 22, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 49 additions & 54 deletions src/jaxsim/parsers/descriptions/model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from __future__ import annotations

import dataclasses
import itertools
from typing import List
from typing import Sequence

from jaxsim import logging

Expand All @@ -13,63 +15,62 @@
@dataclasses.dataclass(frozen=True)
class ModelDescription(KinematicGraph):
"""
Description of a robotic model including links, joints, and collision shapes.

Args:
name (str): The name of the model.
fixed_base (bool): Indicates whether the model has a fixed base.
collision_shapes (List[CollisionShape]): List of collision shapes associated with the model.
Intermediate representation representing the kinematic graph of a robot model.

Attributes:
name (str): The name of the model.
fixed_base (bool): Indicates whether the model has a fixed base.
collision_shapes (List[CollisionShape]): List of collision shapes associated with the model.
name: The name of the model.
fixed_base: Whether the model is either fixed-base or floating-base.
collision_shapes: List of collision shapes associated with the model.
"""

name: str = None

fixed_base: bool = True
collision_shapes: List[CollisionShape] = dataclasses.field(default_factory=list)

collision_shapes: list[CollisionShape] = dataclasses.field(
default_factory=list, repr=False, hash=False
)

@staticmethod
def build_model_from(
name: str,
links: List[LinkDescription],
joints: List[JointDescription],
collisions: List[CollisionShape] = (),
links: list[LinkDescription],
joints: list[JointDescription],
frames: list[LinkDescription] | None = None,
collisions: list[CollisionShape] = (),
fixed_base: bool = False,
base_link_name: str | None = None,
considered_joints: List[str] | None = None,
considered_joints: Sequence[str] | None = None,
model_pose: RootPose = RootPose(),
) -> "ModelDescription":
) -> ModelDescription:
"""
Build a model description from provided components.

Args:
name (str): The name of the model.
links (List[LinkDescription]): List of link descriptions.
joints (List[JointDescription]): List of joint descriptions.
collisions (List[CollisionShape]): List of collision shapes associated with the model.
fixed_base (bool): Indicates whether the model has a fixed base.
base_link_name (str): Name of the base link.
considered_joints (List[str]): List of joint names to consider.
model_pose (RootPose): Pose of the model's root.
name: The name of the model.
links: List of link descriptions.
joints: List of joint descriptions.
frames: List of frame descriptions.
collisions: List of collision shapes associated with the model.
fixed_base: Indicates whether the model has a fixed base.
base_link_name: Name of the base link (i.e. the root of the kinematic tree).
considered_joints: List of joint names to consider (by default all joints).
model_pose: Pose of the model's root (by default an identity transform).

Returns:
ModelDescription: A ModelDescription instance representing the model.

Raises:
ValueError: If invalid or missing input data.
A ModelDescription instance representing the model.
"""

# Create the full kinematic graph
# Create the full kinematic graph.
kinematic_graph = KinematicGraph.build_from(
links=links,
joints=joints,
frames=frames,
root_link_name=base_link_name,
root_pose=model_pose,
)

# Reduce the graph if needed
# Reduce the graph if needed.
if considered_joints is not None:
kinematic_graph = kinematic_graph.reduce(
considered_joints=considered_joints
Expand All @@ -78,11 +79,13 @@ def build_model_from(
# Create the object to compute forward kinematics.
fk = KinematicGraphTransforms(graph=kinematic_graph)

# Store here the final model collisions
final_collisions: List[CollisionShape] = []
# Container of the final model's collision shapes.
final_collisions: list[CollisionShape] = []

# Move and express the collision shapes of the removed link to the lumped link
# Move and express the collision shapes of removed links to the resulting
# lumped link that replace the combination of the removed link and its parent.
for collision_shape in collisions:

# Get all the collidable points of the shape
coll_points = list(collision_shape.collidable_points)

Expand Down Expand Up @@ -112,7 +115,7 @@ def build_model_from(
final_collisions.append(new_collision_shape)

# If the frame was found, update the collidable points' pose and add them
# to the new collision shape
# to the new collision shape
diegoferigo marked this conversation as resolved.
Show resolved Hide resolved
for cp in collision_shape.collidable_points:
# Find the link that is part of the (reduced) model in which the
# collision shape's parent was lumped into
Expand Down Expand Up @@ -145,22 +148,20 @@ def build_model_from(
_joints_removed=kinematic_graph._joints_removed,
)

# Check that the root link of kinematic graph is the desired base link.
assert kinematic_graph.root.name == base_link_name, kinematic_graph.root.name

return model

def reduce(self, considered_joints: List[str]) -> "ModelDescription":
def reduce(self, considered_joints: Sequence[str]) -> ModelDescription:
"""
Reduce the model by removing specified joints.

Args:
considered_joints (List[str]): List of joint names to consider.
The joint names to consider.

Returns:
ModelDescription: A reduced ModelDescription instance.

Raises:
ValueError: If the specified joints are not part of the model.
A `ModelDescription` instance that only includes the considered joints.
"""

if len(set(considered_joints) - set(self.joint_names())) != 0:
Expand All @@ -172,6 +173,7 @@ def reduce(self, considered_joints: List[str]) -> "ModelDescription":
name=self.name,
links=list(self.links_dict.values()),
joints=self.joints,
frames=self.frames,
collisions=self.collision_shapes,
fixed_base=self.fixed_base,
base_link_name=list(iter(self))[0].name,
Expand All @@ -190,12 +192,8 @@ def update_collision_shape_of_link(self, link_name: str, enabled: bool) -> None:
Enable or disable collision shapes associated with a link.

Args:
link_name (str): Name of the link.
enabled (bool): Enable or disable collision shapes associated with the link.

Raises:
ValueError: If the link name is not found in the model.

link_name: The name of the link.
enabled: Enable or disable collision shapes associated with the link.
"""

if link_name not in self.link_names():
Expand All @@ -211,14 +209,10 @@ def collision_shape_of_link(self, link_name: str) -> CollisionShape:
Get the collision shape associated with a specific link.

Args:
link_name (str): Name of the link.
link_name: The name of the link.

Returns:
CollisionShape: The collision shape associated with the link.

Raises:
ValueError: If the link name is not found in the model.

The collision shape associated with the link.
"""

if link_name not in self.link_names():
Expand All @@ -233,14 +227,15 @@ def collision_shape_of_link(self, link_name: str) -> CollisionShape:
]
)

def all_enabled_collidable_points(self) -> List[CollidablePoint]:
def all_enabled_collidable_points(self) -> list[CollidablePoint]:
"""
Get all enabled collidable points in the model.

Returns:
List[CollidablePoint]: A list of all enabled collidable points.
The list of all enabled collidable points.

"""

# Get iterator of all collidable points
all_collidable_points = itertools.chain.from_iterable(
[shape.collidable_points for shape in self.collision_shapes]
Expand Down
Loading