Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose quantities related to generic frames #148

Merged
merged 30 commits into from
May 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
036a29c
Create `jaxsim.api.frame` module with `transform` function
xela-95 May 9, 2024
eb5255f
Add `frame` module to `jaxsim.api` package
xela-95 May 9, 2024
a3f5f63
Add unit test for `jaxsim.api.frame` module
xela-95 May 9, 2024
f84df6c
Add index-related functions to `frame` module
xela-95 May 9, 2024
7fd622e
Add `test_frame_index` to `frame` unit tests
xela-95 May 9, 2024
9bcf0a1
Add `jacobian` method to `frame` module
xela-95 May 9, 2024
b345246
Add `test_frame_jacobians` to `frame` unit tests
xela-95 May 9, 2024
5192b00
Add `.vscode` to gitignore
xela-95 May 9, 2024
6d6614f
Add `frame` module to sphynx documentation
xela-95 May 9, 2024
9ba93a1
Apply suggestions from code review
xela-95 May 9, 2024
77c9e76
Update frame.py
xela-95 May 9, 2024
42de167
Add additional frame attached to link in box model
xela-95 May 9, 2024
33b6961
Refactor `test_frame_jacobians` to better debug jacobians not matchin…
xela-95 May 9, 2024
79518e0
Exclude from `test_frame_jacobians` the frames that are not loaded in…
xela-95 May 10, 2024
b751656
Add `frame_parent_link_name` method to `KinDynComputations class`
xela-95 May 10, 2024
037a626
Update code style of `frame.transform` function
xela-95 May 10, 2024
f17264f
WIP Update `test_frame_transforms` to print parent link frames and no…
xela-95 May 10, 2024
820b510
Update `test_frame_transforms`
xela-95 May 10, 2024
6788c09
Add single pendulum fixture in `conftest.py`
xela-95 May 10, 2024
2a5f859
Clean `test_api_frames`
xela-95 May 10, 2024
11ee5c6
Fix retrieval of the frame's parent link index
diegoferigo May 20, 2024
3d10a2e
Add function to get the frame's parent link index
diegoferigo May 20, 2024
0650d5f
Update frames test
diegoferigo May 21, 2024
7b47fe1
Align link and joint tests
diegoferigo May 20, 2024
e6cd3c0
Removed unused tested model
diegoferigo May 20, 2024
23fef9e
Update tests to use new ROD URDF exporter function
diegoferigo May 22, 2024
79b890f
Use plain integers for frame indices
diegoferigo May 21, 2024
a35482d
Add JaxSimModel.frame_names
diegoferigo May 21, 2024
a0ed4a8
Fix regression raising TracerBoolConversionError when comparing pytrees
diegoferigo May 22, 2024
8b5ac05
Temporarily disable test_pytree
diegoferigo May 22, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# IDEs
.idea*
.vscode/

# Matlab
*.m~
Expand Down Expand Up @@ -141,4 +142,3 @@ src/jaxsim/_version.py
.ruff_cache/
# pixi environments
.pixi

5 changes: 5 additions & 0 deletions docs/modules/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,11 @@ Link
.. automodule:: jaxsim.api.link
:members:

Frame
~~~~~
.. automodule:: jaxsim.api.frame
:members:

CoM
~~~
.. automodule:: jaxsim.api.com
Expand Down
12 changes: 11 additions & 1 deletion src/jaxsim/api/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,13 @@
from . import common # isort:skip
from . import model, data # isort:skip
from . import com, contact, joint, kin_dyn_parameters, link, ode, ode_data, references
from . import (
com,
contact,
frame,
joint,
kin_dyn_parameters,
link,
ode,
ode_data,
references,
)
221 changes: 221 additions & 0 deletions src/jaxsim/api/frame.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,221 @@
import functools
from typing import Sequence

import jax
import jax.numpy as jnp
import jaxlie
import numpy as np

import jaxsim.api as js
import jaxsim.math
import jaxsim.typing as jtp

from .common import VelRepr

# =======================
# Index-related functions
# =======================


def idx_of_parent_link(model: js.model.JaxSimModel, *, frame_idx: jtp.IntLike) -> int:
"""
Get the index of the link to which the frame is rigidly attached.

Args:
model: The model to consider.
frame_idx: The index of the frame.

Returns:
The index of the frame's parent link.
"""

# Get the intermediate representation parsed from the model description.
ir = model.description.get()

# Extract the indices of the frame and the link it is attached to.
F = ir.frames[frame_idx - model.number_of_links()]
L = ir.links_dict[F.parent.name].index

return int(L)


def name_to_idx(model: js.model.JaxSimModel, *, frame_name: str) -> int:
"""
Convert the name of a frame to its index.

Args:
model: The model to consider.
frame_name: The name of the frame.

Returns:
The index of the frame.
"""

frame_names = np.array([frame.name for frame in model.description.get().frames])

if frame_name in frame_names:
idx_in_list = np.argwhere(frame_names == frame_name)
return int(idx_in_list.squeeze().tolist()) + model.number_of_links()

return -1


def idx_to_name(model: js.model.JaxSimModel, *, frame_index: jtp.IntLike) -> str:
"""
Convert the index of a frame to its name.

Args:
model: The model to consider.
frame_index: The index of the frame.

Returns:
The name of the frame.
"""

return model.description.get().frames[frame_index - model.number_of_links()].name


@functools.partial(jax.jit, static_argnames=["frame_names"])
def names_to_idxs(
model: js.model.JaxSimModel, *, frame_names: Sequence[str]
) -> jax.Array:
"""
Convert a sequence of frame names to their corresponding indices.

Args:
model: The model to consider.
frame_names: The names of the frames.

Returns:
The indices of the frames.
"""

return jnp.array(
[name_to_idx(model=model, frame_name=frame_name) for frame_name in frame_names]
).astype(int)


def idxs_to_names(
model: js.model.JaxSimModel, *, frame_indices: Sequence[jtp.IntLike]
) -> tuple[str, ...]:
"""
Convert a sequence of frame indices to their corresponding names.

Args:
model: The model to consider.
frame_indices: The indices of the frames.

Returns:
The names of the frames.
"""

return tuple(
idx_to_name(model=model, frame_index=frame_index)
for frame_index in frame_indices
)


# ==========
# Frame APIs
# ==========


@functools.partial(jax.jit, static_argnames=["frame_index"])
def transform(
model: js.model.JaxSimModel,
data: js.data.JaxSimModelData,
*,
frame_index: jtp.IntLike,
) -> jtp.Matrix:
"""
Compute the SE(3) transform from the world frame to the specified frame.

Args:
model: The model to consider.
data: The data of the considered model.
frame_index: The index of the frame for which the transform is requested.

Returns:
The 4x4 matrix representing the transform.
"""

# Compute the necessary transforms.
L = idx_of_parent_link(model=model, frame_idx=frame_index)
W_H_L = js.link.transform(model=model, data=data, link_index=L)

# Get the static frame pose wrt the parent link.
frame = model.description.get().frames[frame_index - model.number_of_links()]
L_H_F = frame.pose

# Combine the transforms computing the frame pose.
return W_H_L @ L_H_F


@functools.partial(jax.jit, static_argnames=["frame_index", "output_vel_repr"])
def jacobian(
model: js.model.JaxSimModel,
data: js.data.JaxSimModelData,
*,
frame_index: jtp.IntLike,
output_vel_repr: VelRepr | None = None,
) -> jtp.Matrix:
"""
Compute the free-floating jacobian of the frame.

Args:
model: The model to consider.
data: The data of the considered model.
frame_index: The index of the frame.
output_vel_repr:
The output velocity representation of the free-floating jacobian.

Returns:
The 6×(6+n) free-floating jacobian of the frame.

Note:
The input representation of the free-floating jacobian is the active
velocity representation.
"""

output_vel_repr = (
output_vel_repr if output_vel_repr is not None else data.velocity_representation
)

# Get the index of the parent link.
L = idx_of_parent_link(model=model, frame_idx=frame_index)

# Compute the Jacobian of the parent link using body-fixed output representation.
L_J_WL = js.link.jacobian(
model=model, data=data, link_index=L, output_vel_repr=VelRepr.Body
)

# Adjust the output representation
match output_vel_repr:
case VelRepr.Inertial:
W_H_L = js.link.transform(model=model, data=data, link_index=L)
W_X_L = jaxlie.SE3.from_matrix(W_H_L).adjoint()
W_J_WL = W_X_L @ L_J_WL
O_J_WL_I = W_J_WL

case VelRepr.Body:
W_H_L = js.link.transform(model=model, data=data, link_index=L)
W_H_F = transform(model=model, data=data, frame_index=frame_index)
F_H_L = jaxsim.math.Transform.inverse(W_H_F) @ W_H_L
F_X_L = jaxlie.SE3.from_matrix(F_H_L).adjoint()
F_J_WL = F_X_L @ L_J_WL
O_J_WL_I = F_J_WL

case VelRepr.Mixed:
W_H_L = js.link.transform(model=model, data=data, link_index=L)
W_H_F = transform(model=model, data=data, frame_index=frame_index)
F_H_L = jaxsim.math.Transform.inverse(W_H_F) @ W_H_L
FW_H_F = W_H_F.at[0:3, 3].set(jnp.zeros(3))
FW_H_L = FW_H_F @ F_H_L
FW_X_L = jaxlie.SE3.from_matrix(FW_H_L).adjoint()
FW_J_WL = FW_X_L @ L_J_WL
O_J_WL_I = FW_J_WL

case _:
raise ValueError(output_vel_repr)

return O_J_WL_I
10 changes: 10 additions & 0 deletions src/jaxsim/api/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,16 @@ def link_names(self) -> tuple[str, ...]:

return self.kin_dyn_parameters.link_names

def frame_names(self) -> tuple[str, ...]:
"""
Return the names of the links in the model.

Returns:
The names of the links in the model.
"""

return tuple([frame.name for frame in self.description.get().frames])


# =====================
# Model post-processing
Expand Down
66 changes: 37 additions & 29 deletions src/jaxsim/parsers/descriptions/link.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import dataclasses
from typing import List

import jax.numpy as jnp
import jax_dataclasses
Expand All @@ -16,39 +17,46 @@ class LinkDescription(JaxsimDataclass):
In-memory description of a robot link.

Attributes:
name (str): The name of the link.
mass (float): The mass of the link.
inertia (jtp.Matrix): The inertia matrix of the link.
index (Optional[int]): An optional index for the link.
parent (Optional[LinkDescription]): The parent link of this link.
pose (jtp.Matrix): The pose transformation matrix of the link.
children (List[LinkDescription]): List of child links.
name: The name of the link.
mass: The mass of the link.
inertia: The inertia tensor of the link.
index: An optional index for the link (it gets automatically assigned).
parent: The parent link of this link.
pose: The pose transformation matrix of the link.
children: List of child links.
"""

name: Static[str]
mass: float
inertia: jtp.Matrix
mass: float = dataclasses.field(repr=False)
inertia: jtp.Matrix = dataclasses.field(repr=False)
index: int | None = None
parent: Static["LinkDescription"] = dataclasses.field(default=None, repr=False)
parent: LinkDescription = dataclasses.field(default=None, repr=False)
pose: jtp.Matrix = dataclasses.field(default_factory=lambda: jnp.eye(4), repr=False)
children: Static[List["LinkDescription"]] = dataclasses.field(

children: Static[list[LinkDescription]] = dataclasses.field(
default_factory=list, repr=False
)

def __hash__(self) -> int:
return hash(self.__repr__())

def __eq__(self, other) -> bool:
return (
self.name == other.name
and self.mass == other.mass
and (self.inertia == other.inertia).all()
and self.index == other.index
and self.parent == other.parent
and (self.pose == other.pose).all()
and self.children == other.children

return hash(
(
hash(self.name),
hash(float(self.mass)),
hash(tuple(self.inertia.flatten().tolist())),
hash(int(self.index)),
hash(self.parent),
hash(tuple(hash(c) for c in self.children)),
)
)

def __eq__(self, other: LinkDescription) -> bool:

if not isinstance(other, LinkDescription):
return False

return hash(self) == hash(other)

@property
def name_and_index(self) -> str:
"""
Expand All @@ -61,19 +69,19 @@ def name_and_index(self) -> str:
return f"#{self.index}_<{self.name}>"

def lump_with(
self, link: "LinkDescription", lumped_H_removed: jtp.Matrix
) -> "LinkDescription":
self, link: LinkDescription, lumped_H_removed: jtp.Matrix
) -> LinkDescription:
"""
Combine the current link with another link, preserving mass and inertia.

Args:
link (LinkDescription): The link to combine with.
lumped_H_removed (jtp.Matrix): The transformation matrix between the two links.
link: The link to combine with.
lumped_H_removed: The transformation matrix between the two links.

Returns:
LinkDescription: The combined link.

The combined link.
"""

# Get the 6D inertia of the link to remove
I_removed = link.inertia

Expand Down
Loading