Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Speed up passing different JaxSimModel with same pytree structure to JIT-compiled functions #179

Merged
merged 12 commits into from
Jun 14, 2024
6 changes: 4 additions & 2 deletions src/jaxsim/api/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,14 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):

def __hash__(self) -> int:

from jaxsim.utils.wrappers import HashedNumpyArray

return hash(
(
hash(self.state),
hash(tuple(self.gravity.flatten().tolist())),
HashedNumpyArray.hash_of_array(self.gravity),
hash(self.soft_contacts_params),
hash(jnp.atleast_1d(self.time_ns).flatten().tolist()),
HashedNumpyArray.hash_of_array(self.time_ns),
)
)

Expand Down
56 changes: 30 additions & 26 deletions src/jaxsim/api/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@
# =======================


def idx_of_parent_link(model: js.model.JaxSimModel, *, frame_idx: jtp.IntLike) -> int:
def idx_of_parent_link(
model: js.model.JaxSimModel, *, frame_idx: jtp.IntLike
) -> jtp.Int:
"""
Get the index of the link to which the frame is rigidly attached.

Expand All @@ -29,17 +31,13 @@ def idx_of_parent_link(model: js.model.JaxSimModel, *, frame_idx: jtp.IntLike) -
The index of the frame's parent link.
"""

# Get the intermediate representation parsed from the model description.
ir = model.description
return model.kin_dyn_parameters.frame_parameters.body[
frame_idx - model.number_of_links()
]

# Extract the indices of the frame and the link it is attached to.
F = ir.frames[frame_idx - model.number_of_links()]
L = ir.links_dict[F.parent.name].index

return int(L)


def name_to_idx(model: js.model.JaxSimModel, *, frame_name: str) -> int:
@functools.partial(jax.jit, static_argnames="frame_name")
def name_to_idx(model: js.model.JaxSimModel, *, frame_name: str) -> jtp.Int:
"""
Convert the name of a frame to its index.

Expand All @@ -51,13 +49,19 @@ def name_to_idx(model: js.model.JaxSimModel, *, frame_name: str) -> int:
The index of the frame.
"""

frame_names = np.array([frame.name for frame in model.description.frames])
if frame_name in model.kin_dyn_parameters.frame_parameters.name:
return (
jnp.array(
np.argwhere(
np.array(model.kin_dyn_parameters.frame_parameters.name)
== frame_name
)
)
.squeeze()
.astype(int)
) + model.number_of_links()

if frame_name in frame_names:
idx_in_list = np.argwhere(frame_names == frame_name)
return int(idx_in_list.squeeze().tolist()) + model.number_of_links()

return -1
return jnp.array(-1).astype(int)


def idx_to_name(model: js.model.JaxSimModel, *, frame_index: jtp.IntLike) -> str:
Expand All @@ -72,7 +76,9 @@ def idx_to_name(model: js.model.JaxSimModel, *, frame_index: jtp.IntLike) -> str
The name of the frame.
"""

return model.description.frames[frame_index - model.number_of_links()].name
return model.kin_dyn_parameters.frame_parameters.name[
frame_index - model.number_of_links()
]


@functools.partial(jax.jit, static_argnames=["frame_names"])
Expand All @@ -91,7 +97,7 @@ def names_to_idxs(
"""

return jnp.array(
[name_to_idx(model=model, frame_name=frame_name) for frame_name in frame_names]
[name_to_idx(model=model, frame_name=name) for name in frame_names]
).astype(int)


Expand All @@ -109,18 +115,15 @@ def idxs_to_names(
The names of the frames.
"""

return tuple(
idx_to_name(model=model, frame_index=frame_index)
for frame_index in frame_indices
)
return tuple(idx_to_name(model=model, frame_index=idx) for idx in frame_indices)


# ==========
# Frame APIs
# ==========


@functools.partial(jax.jit, static_argnames=["frame_index"])
@jax.jit
def transform(
model: js.model.JaxSimModel,
data: js.data.JaxSimModelData,
Expand All @@ -144,14 +147,15 @@ def transform(
W_H_L = js.link.transform(model=model, data=data, link_index=L)

# Get the static frame pose wrt the parent link.
frame = model.description.frames[frame_index - model.number_of_links()]
L_H_F = frame.pose
L_H_F = model.kin_dyn_parameters.frame_parameters.transform[
frame_index - model.number_of_links()
]

# Combine the transforms computing the frame pose.
return W_H_L @ L_H_F


@functools.partial(jax.jit, static_argnames=["frame_index", "output_vel_repr"])
@functools.partial(jax.jit, static_argnames=["output_vel_repr"])
def jacobian(
model: js.model.JaxSimModel,
data: js.data.JaxSimModelData,
Expand Down
97 changes: 90 additions & 7 deletions src/jaxsim/api/kin_dyn_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import jax.numpy as jnp
import jax_dataclasses
import jaxlie
import numpy as np
from jax_dataclasses import Static

import jaxsim.typing as jtp
Expand All @@ -15,7 +14,7 @@
from jaxsim.utils import HashedNumpyArray, JaxsimDataclass


@jax_dataclasses.pytree_dataclass
@jax_dataclasses.pytree_dataclass(eq=False, unsafe_hash=False)
class KynDynParameters(JaxsimDataclass):
r"""
Class storing the kinematic and dynamic parameters of a model.
Expand All @@ -26,6 +25,7 @@ class KynDynParameters(JaxsimDataclass):
support_body_array_bool:
The boolean support parent array :math:`\kappa_{b}(i)` of the model.
link_parameters: The parameters of the links.
frame_parameters: The parameters of the frames.
contact_parameters: The parameters of the collidable points.
joint_model: The joint model of the model.
joint_parameters: The parameters of the joints.
Expand All @@ -42,6 +42,9 @@ class KynDynParameters(JaxsimDataclass):
# Contacts
contact_parameters: ContactParameters

# Frames
frame_parameters: FrameParameters

# Joints
joint_model: JointModel
joint_parameters: JointParameters | None
Expand Down Expand Up @@ -141,6 +144,19 @@ def build(model_description: ModelDescription) -> KynDynParameters:
model_description=model_description
)

# =================
# Frames properties
# =================

# Create the object storing the parameters of frames.
# Note that, contrarily to LinkParameters and JointsParameters, this object
# is not created with vmap. This is because the "name" attribute of the object
# must be Static for JIT-related reasons, and tree_map would not consider it
# as a leaf.
frame_parameters = FrameParameters.build_from(
model_description=model_description
)

# ===============
# Tree properties
# ===============
Expand Down Expand Up @@ -206,6 +222,7 @@ def scan_body(carry: tuple, i: jtp.Int) -> tuple[tuple, None]:
joint_model=joint_model,
joint_parameters=joint_parameters,
contact_parameters=contact_parameters,
frame_parameters=frame_parameters,
)

def __eq__(self, other: KynDynParameters) -> bool:
Expand All @@ -221,7 +238,8 @@ def __hash__(self) -> int:
(
hash(self.number_of_links()),
hash(self.number_of_joints()),
hash(tuple(np.atleast_1d(self.parent_array).flatten().tolist())),
hash(self.frame_parameters.name),
hash(tuple(self.frame_parameters.body.tolist())),
hash(self._parent_array),
hash(self._support_body_array_bool),
)
Expand Down Expand Up @@ -730,7 +748,7 @@ class ContactParameters(JaxsimDataclass):
A tuple of integers representing, for each collidable point, the index of
the body (link) to which it is rigidly attached to.
point:
The translation between the link frame and the collidable point, expressed
The translations between the link frame and the collidable point, expressed
in the coordinates of the parent link frame.

Note:
Expand Down Expand Up @@ -773,10 +791,75 @@ def build_from(model_description: ModelDescription) -> ContactParameters:
links_dict[cp.parent_link.name].index for cp in collidable_points
)

# Build the GroundContact object.
# Build the ContactParameters object.
cp = ContactParameters(point=points, body=link_index_of_points) # noqa

assert cp.point.shape[1] == 3
assert cp.point.shape[0] == len(cp.body)
assert cp.point.shape[1] == 3, cp.point.shape[1]
assert cp.point.shape[0] == len(cp.body), cp.point.shape[0]

return cp


@jax_dataclasses.pytree_dataclass
class FrameParameters(JaxsimDataclass):
"""
Class storing the frame parameters of a model.

Attributes:
name: A tuple of strings defining the frame names.
body:
A vector of integers representing, for each frame, the index of
the body (link) to which it is rigidly attached to.
transform: The transforms of the frames w.r.t. their parent link.

Note:
Contrarily to LinkParameters and JointParameters, this class is not meant
to be created with vmap. This is because the `name` attribute must be `Static`.
"""

name: Static[tuple[str, ...]] = dataclasses.field(default_factory=tuple)

body: jtp.Vector = dataclasses.field(default_factory=lambda: jnp.array([]))

transform: jtp.Array = dataclasses.field(default_factory=lambda: jnp.array([]))

@staticmethod
def build_from(model_description: ModelDescription) -> FrameParameters:
"""
Build a FrameParameters object from a model description.

Args:
model_description: The model description to consider.

Returns:
The FrameParameters object.
"""

if len(model_description.frames) == 0:
return FrameParameters()

# Extract the frame names.
names = tuple(frame.name for frame in model_description.frames)

# For each frame, extract the index of the link to which it is attached to.
parent_link_index_of_frames = tuple(
model_description.links_dict[frame.parent.name].index
for frame in model_description.frames
)

# For each frame, extract the transform w.r.t. its parent link.
transforms = jnp.atleast_3d(
jnp.stack([frame.pose for frame in model_description.frames])
)

# Build the FrameParameters object.
fp = FrameParameters(
name=names,
transform=transforms.astype(float),
body=jnp.array(parent_link_index_of_frames).astype(int),
)

assert fp.transform.shape[1:] == (4, 4), fp.transform.shape[1:]
assert fp.transform.shape[0] == len(fp.body), fp.transform.shape[0]

return fp
39 changes: 26 additions & 13 deletions src/jaxsim/api/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@
import jaxsim.parsers.descriptions
import jaxsim.typing as jtp
from jaxsim.math import Cross
from jaxsim.utils import JaxsimDataclass, Mutability
from jaxsim.utils import JaxsimDataclass, Mutability, wrappers

from .common import VelRepr


@jax_dataclasses.pytree_dataclass
@jax_dataclasses.pytree_dataclass(eq=False, unsafe_hash=False)
class JaxSimModel(JaxsimDataclass):
"""
The JaxSim model defining the kinematics and dynamics of a robot.
Expand All @@ -31,34 +31,43 @@ class JaxSimModel(JaxsimDataclass):
model_name: Static[str]

terrain: Static[jaxsim.terrain.Terrain] = dataclasses.field(
default=jaxsim.terrain.FlatTerrain(), repr=False, compare=False, hash=False
default=jaxsim.terrain.FlatTerrain(), repr=False
)

kin_dyn_parameters: js.kin_dyn_parameters.KynDynParameters | None = (
dataclasses.field(default=None, repr=False, compare=False, hash=False)
dataclasses.field(default=None, repr=False)
)

built_from: Static[str | pathlib.Path | rod.Model | None] = dataclasses.field(
default=None, repr=False, compare=False, hash=False
default=None, repr=False
)

description: Static[jaxsim.parsers.descriptions.ModelDescription | None] = (
dataclasses.field(default=None, repr=False, compare=False, hash=False)
)
_description: Static[
wrappers.HashlessObject[jaxsim.parsers.descriptions.ModelDescription | None]
] = dataclasses.field(default=None, repr=False)

@property
def description(self) -> jaxsim.parsers.descriptions.ModelDescription:
return self._description.get()

def __eq__(self, other: JaxSimModel) -> bool:

if not isinstance(other, JaxSimModel):
return False

return hash(self) == hash(other)
if self.model_name != other.model_name:
return False

if self.kin_dyn_parameters != other.kin_dyn_parameters:
return False

return True

def __hash__(self) -> int:

return hash(
(
hash(self.model_name),
hash(self.description),
hash(self.kin_dyn_parameters),
)
)
Expand Down Expand Up @@ -152,10 +161,10 @@ def build(
# Set the model name (if not provided, use the one from the model description)
model_name = model_name if model_name is not None else model_description.name

# Build the model
# Build the model.
model = JaxSimModel(
model_name=model_name,
description=model_description,
_description=wrappers.HashlessObject(obj=model_description),
kin_dyn_parameters=js.kin_dyn_parameters.KynDynParameters.build(
model_description=model_description
),
Expand Down Expand Up @@ -270,6 +279,10 @@ def link_names(self) -> tuple[str, ...]:

return self.kin_dyn_parameters.link_names

# =====================
# Frame-related methods
# =====================

def frame_names(self) -> tuple[str, ...]:
"""
Return the names of the links in the model.
Expand All @@ -278,7 +291,7 @@ def frame_names(self) -> tuple[str, ...]:
The names of the links in the model.
"""

return tuple(frame.name for frame in self.description.frames)
return self.kin_dyn_parameters.frame_parameters.name


# =====================
Expand Down
Loading