Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Read frames from the original model description #150

Merged
merged 3 commits into from
May 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ dependencies:
- jaxlie >= 1.3.0
- jax-dataclasses >= 1.4.0
- pptree
- rod >= 0.2.0
- rod >= 0.3.0
- typing_extensions # python<3.12
# ====================================
# Optional dependencies from setup.cfg
Expand Down Expand Up @@ -41,18 +41,15 @@ dependencies:
- pip
- sphinx
- sphinx-autodoc-typehints
- sphinx-book-theme
- sphinx-copybutton
- sphinx-design
- sphinx_fontawesome
- sphinx-jinja2-compat
- sphinx-multiversion
- sphinx_rtd_theme
- sphinx-book-theme
- sphinx-toolbox
# ========================================
# Other dependencies for GitHub Codespaces
# ========================================
# System dependencies to run the tests
- gz-sim7
# Other packages
- ipython
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ install_requires =
jaxlie >= 1.3.0
jax_dataclasses >= 1.4.0
pptree
rod >= 0.2.0
rod >= 0.3.0
typing_extensions ; python_version < '3.12'

[options.packages.find]
Expand Down
3 changes: 2 additions & 1 deletion src/jaxsim/parsers/kinematic_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,8 @@ def __post_init__(self):
# Also here, we assume the model is fixed-base, therefore the first frame will
# have last_link_idx + 1. These frames are not part of the physics model.
for index, frame in enumerate(self.frames):
frame.index = index + len(self.link_names())
with frame.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION):
frame.index = int(index + len(self.link_names()))

# Number joints so that their index matches their child link index
links_dict = {l.name: l for l in iter(self)}
Expand Down
74 changes: 70 additions & 4 deletions src/jaxsim/parsers/rod/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import numpy as np
import rod

import jaxsim.utils
from jaxsim import logging
from jaxsim.math.quaternion import Quaternion
from jaxsim.parsers import descriptions, kinematic_graph
Expand All @@ -25,6 +26,7 @@ class SDFData(NamedTuple):

link_descriptions: List[descriptions.LinkDescription]
joint_descriptions: List[descriptions.JointDescription]
frame_descriptions: List[descriptions.LinkDescription]
collision_shapes: List[descriptions.CollisionShape]

sdf_model: rod.Model | None = None
Expand Down Expand Up @@ -70,6 +72,8 @@ def extract_model_data(

# Jaxsim supports only models compatible with URDF, i.e. those having all links
# directly attached to their parent joint without additional roto-translations.
# Furthermore, the following switch also post-processes frames such that their
# pose is expressed wrt the parent link they are rigidly attached to.
sdf_model.switch_frame_convention(frame_convention=rod.FrameConvention.Urdf)

# Log type of base link
Expand Down Expand Up @@ -113,6 +117,23 @@ def extract_model_data(
# Create a dictionary to find easily links
links_dict: Dict[str, descriptions.LinkDescription] = {l.name: l for l in links}

# ============
# Parse frames
# ============

# Parse the frames (unconnected)
frames = [
descriptions.LinkDescription(
name=f.name,
mass=jnp.array(0.0, dtype=float),
inertia=jnp.zeros(shape=(3, 3)),
parent=links_dict[f.attached_to],
pose=f.pose.transform() if f.pose is not None else jnp.eye(4),
diegoferigo marked this conversation as resolved.
Show resolved Hide resolved
)
for f in sdf_model.frames()
if f.attached_to in links_dict
]

# =========================
# Process fixed-base models
# =========================
Expand Down Expand Up @@ -309,6 +330,7 @@ def extract_model_data(
model_name=sdf_model.name,
link_descriptions=links,
joint_descriptions=joints,
frame_descriptions=frames,
collision_shapes=collisions,
fixed_base=sdf_model.is_fixed_base(),
base_link_name=sdf_model.get_canonical_link(),
Expand Down Expand Up @@ -338,10 +360,14 @@ def build_model_description(
model_description=model_description, model_name=None, is_urdf=is_urdf
)

# Build the model description.
# Build the intermediate representation used for building a JaxSim model.
# This process, beyond other operations, removes the fixed joints.
# Note: if the model is fixed-base, the fixed joint between world and the first
# link is removed and the pose of the first link is updated.
model = descriptions.ModelDescription.build_model_from(
#
# The whole process is:
# URDF/SDF ⟶ rod.Model ⟶ ModelDescription ⟶ JaxSimModel.
graph = descriptions.ModelDescription.build_model_from(
name=sdf_data.model_name,
links=sdf_data.link_descriptions,
joints=sdf_data.joint_descriptions,
Expand All @@ -356,7 +382,47 @@ def build_model_description(
],
)

# Depending on how the model is reduced due to the removal of fixed joints,
# there might be frames that are no longer attached to existing links.
# We need to change the link to which they are attached to, and update their pose.
frames_with_no_parent_link = (
f for f in sdf_data.frame_descriptions if f.parent.name not in graph
)

# Build the object to compute forward kinematics.
fk = kinematic_graph.KinematicGraphTransforms(graph=graph)

for frame in frames_with_no_parent_link:
# Get the original data of the frame.
original_pose = frame.pose
original_parent_link = frame.parent.name

# The parent link, that has been removed, became a frame.
assert original_parent_link in graph.frames_dict, (frame, original_parent_link)

# Get the new parent of the frame corresponding to the removed parent link.
new_parent_link = graph.frames_dict[original_parent_link].parent.name
logging.debug(f"Frame '{frame.name}' is now attached to '{new_parent_link}'")

# Get the transform from the new parent link to the original parent link.
# The original pose is expressed wrt the original parent link.
F_H_P = fk.relative_transform(
relative_to=new_parent_link, name=original_parent_link
)

# Update the frame with the updated data.
with frame.mutable_context(
mutability=jaxsim.utils.Mutability.MUTABLE_NO_VALIDATION
):
frame.parent = graph.links_dict[new_parent_link]
frame.pose = np.array(F_H_P @ original_pose)

# Include the SDF frames originally stored in the SDF.
graph = dataclasses.replace(
graph, frames=sdf_data.frame_descriptions + graph.frames
)

# Store the parsed SDF tree as extra info
model = dataclasses.replace(model, extra_info={"sdf_model": sdf_data.sdf_model})
graph = dataclasses.replace(graph, extra_info={"sdf_model": sdf_data.sdf_model})

return model
return graph